This notebook implements multiple standard CSP pipelines and tests their performance on data from the database provided by Kaya et al. The knowledge and utilities obtained from experimental notebooks four and five are used throughout this notebook.
This notebook works in an offline fashion and uses epochs with a length of 3 seconds. Each epoch starts 1 second before the visual cue is given, includes the 1 second during which the visual cue is shown and ends 1 second after the visual cue is hidden, totalling 3 seconds. Baseline correction is done on the first second of the epoch, meaning the second before the visual cue is shown. The effective training and testing are done in a half-second window that starts 0.1 seconds after the onset of the visual cue, i.e. from 0.1 to 0.6 seconds after cue onset. A window of 0.5 seconds was chosen as it is a common size for sliding window approaches in online systems.
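The timing described above can be checked with a little arithmetic. The sketch below converts the epoch and window boundaries into sample indices. The 200 Hz sampling frequency and the helper name `time_to_sample` are assumptions for illustration only; the real value should be read from the recording (for example `raw.info['sfreq']` in MNE).

```python
# Assumed sampling frequency in Hz; read the real value from the recording.
SFREQ = 200.0

# Epoch boundaries in seconds relative to cue onset (taken from the text above).
EPOCH_TMIN, EPOCH_TMAX = -1.0, 2.0
# Training/testing window in seconds relative to cue onset.
WINDOW_TMIN, WINDOW_TMAX = 0.1, 0.6


def time_to_sample(t, tmin=EPOCH_TMIN, sfreq=SFREQ):
    """Hypothetical helper: index of time t (s) inside an epoch that starts at tmin."""
    return int(round((t - tmin) * sfreq))


epoch_duration = EPOCH_TMAX - EPOCH_TMIN
window_start = time_to_sample(WINDOW_TMIN)
window_stop = time_to_sample(WINDOW_TMAX)
window_duration = (window_stop - window_start) / SFREQ

print(f"Epoch duration: {epoch_duration} s")
print(f"Window samples: {window_start} to {window_stop}")
print(f"Window duration: {window_duration} s")
```

Under the assumed sampling frequency, the half-second window covers 100 samples per channel, which is all the classifier gets to see per trial.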
Instructions on where to get the data are available on the GitHub repository of the BCI master thesis project. These instructions are under bci-master-thesis/code/data/CLA/README.md. We will use the utility file bci-master-thesis/code/utils/CLA_dataset.py to work with this data. The data was stored as FIF files, which are included in the GitHub repository of the BCI master thesis project.
The bci-master-thesis Anaconda environment should be active to ensure proper support. Installation instructions are available on the GitHub repository of the BCI master thesis project.
####################################################
# CHECKING FOR RIGHT ANACONDA ENVIRONMENT
####################################################
import os
from platform import python_version
from pathlib import Path
from copy import copy
print(f"Active environment: {os.environ['CONDA_DEFAULT_ENV']}")
print(f"Correct environment: {os.environ['CONDA_DEFAULT_ENV'] == 'bci-master-thesis'}")
print(f"\nPython version: {python_version()}")
print(f"Correct Python version: {python_version() == '3.8.10'}")
Active environment: bci-master-thesis
Correct environment: True

Python version: 3.8.10
Correct Python version: True
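Note that `os.environ['CONDA_DEFAULT_ENV']` raises a `KeyError` when the notebook is started outside a Conda environment. A more forgiving variant is sketched below. The function name `check_environment` is hypothetical, and the expected values are simply the ones used in the cell above.

```python
import os
from platform import python_version

EXPECTED_ENV = "bci-master-thesis"
EXPECTED_PYTHON = "3.8.10"


def check_environment(environ=os.environ):
    """Hypothetical helper: return (active_env, env_ok, python_ok)."""
    # .get() returns None instead of raising when the variable is unset
    active_env = environ.get("CONDA_DEFAULT_ENV")
    env_ok = active_env == EXPECTED_ENV
    python_ok = python_version() == EXPECTED_PYTHON
    return active_env, env_ok, python_ok


active_env, env_ok, python_ok = check_environment()
print(f"Active environment: {active_env}")
print(f"Correct environment: {env_ok}")
print(f"Correct Python version: {python_ok}")
```

Passing the environment mapping as an argument keeps the check easy to exercise without changing the real process environment.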
####################################################
# LOADING MODULES
####################################################
# Load util function file
import sys
sys.path.append('../utils')
import CLA_dataset
# IO functions
from IPython.utils import io
import copy
# Set logging level for MNE before loading MNE
os.environ['MNE_LOGGING_LEVEL'] = 'WARNING'
# Modules tailored for EEG data
import mne; print(f"MNE version (1.0.2 recommended): {mne.__version__}")
from mne.decoding import CSP
# ML libraries
import sklearn; print(f"Scikit-learn version (1.0.2 recommended): {sklearn.__version__}")
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
# Data manipulation modules
import numpy as np; print(f"Numpy version (1.21.5 recommended): {np.__version__}")
import pandas as pd; print(f"Pandas version (1.4.1 recommended): {pd.__version__}")
# Plotting
import matplotlib; print(f"Matplotlib version (3.5.1 recommended): {matplotlib.__version__}")
import matplotlib.pyplot as plt
# Storing files
import pickle; print(f"Pickle version (4.0 recommended): {pickle.format_version}")
MNE version (1.0.2 recommended): 1.0.2
Scikit-learn version (1.0.2 recommended): 1.0.2
Numpy version (1.21.5 recommended): 1.21.5
Pandas version (1.4.1 recommended): 1.4.1
Matplotlib version (3.5.1 recommended): 3.5.1
Pickle version (4.0 recommended): 4.0
As mentioned, this notebook uses the database provided by Kaya et al., specifically the CLA dataset. Instructions on where to get the data are available on the GitHub repository of the BCI master thesis project. These instructions are under bci-master-thesis/code/data/CLA/README.md. The following code block checks if all required files are available.
####################################################
# CHECKING FILE ACCESS
####################################################
# Use util to determine if we have access
print("Full Matlab CLA file access: " + str(CLA_dataset.check_matlab_files_availability()))
print("Full MNE CLA file access: " + str(CLA_dataset.check_mne_files_availability()))
Full Matlab CLA file access: True
Full MNE CLA file access: True
As discussed in the master's thesis, a classification system can be trained and tested using multiple strategies. The simplest is to train a classifier on a single subject, using a single session, and to test it on that same session. This is an over-optimistic testing scenario with a high risk of overfitting and poor generalisation to new sessions or new subjects, but it is a reasonable baseline to see whether anything can be learned at all. We do this for three traditional machine learning classifiers: linear discriminant analysis (LDA), support vector machines (SVM) and random forest (RF). K-nearest neighbours (KNN) is not considered because its predictions are too time-consuming. More complex models such as a multilayer perceptron (MLP) are not considered either, as they are an integral part of the deep learning models covered in later notebooks.
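To judge whether anything is learned at all, it helps to compare accuracies against the chance level. The sketch below is a minimal illustration that assumes a balanced three-class label set, in line with the three MI tasks and the uniform LDA priors used in the experiment. The 320 trials per class and the helper name `majority_baseline_accuracy` are hypothetical.

```python
from collections import Counter


def majority_baseline_accuracy(labels):
    """Hypothetical helper: accuracy of always predicting the most frequent label."""
    counts = Counter(labels)
    return max(counts.values()) / len(labels)


# Hypothetical balanced label set with the three task names used in this notebook
labels = ["neutral", "left", "right"] * 320

chance = majority_baseline_accuracy(labels)
print(f"Chance level for {len(set(labels))} balanced classes: {chance:.4f}")
```

Any cross-validated accuracy should be read relative to this baseline rather than relative to 0.5.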
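This experiment works as follows:
- For each of the subjects B, C and E, the most recent recording is loaded.
- Epochs are created from one second before the visual cue to one second after the visual cue is hidden, with baseline correction on the pre-cue second. Only the epochs of the neutral, left and right MI tasks are kept.
- The epochs are band-pass filtered between 2 Hz and 32 Hz with a fixed minimum-phase FIR filter.
- The window from 0.1 to 0.6 seconds after cue onset is extracted.
- The data is split into a stratified 80% train and 20% test set.
- A grid search with stratified 4-fold cross validation is run on the train set for a CSP + LDA pipeline.
- The grid search results and the train and test data are stored so the best model can be retrained later.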
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        # Get MNE raw object for latest recording of that subject
        mne_raw = CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_id)
        # Get epochs for that MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Create a test and train split
        X_train, X_test, y_train, y_test = train_test_split(mne_epochs_data,
                                                            labels,
                                                            test_size = 0.2,
                                                            shuffle= True,
                                                            stratify= labels,
                                                            random_state= 1998)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        lda = LinearDiscriminantAnalysis(shrinkage= None,
                                         priors=[1/3, 1/3, 1/3])
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('LDA', lda)])
        # Configure cross validation to use
        cv = StratifiedKFold(n_splits=4,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        param_grid = [{"CSP__n_components": [2, 3, 4, 6, 10],
                       "LDA__solver": ["svd"],
                       "LDA__tol": [0.0001, 0.00001, 0.001, 0.0004, 0.00007]
                       },
                      {"CSP__n_components": [2, 3, 4, 6, 10],
                       "LDA__solver": ["lsqr" , "eigen"]
                       }]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= X_train,
                        y= y_train)
        # Store the results of the grid search
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/gridsearch_csplda_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Store the train and test data so the best model can be retrained later
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/testdata-x_csplda_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(X_test, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/testdata-y_svm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(y_test, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/traindata-x_svm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(X_train, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/traindata-y_csplda_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(y_train, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del lda
        del pipeline
        del labels
        del cv
        del file
        del X_train
        del X_test
        del y_train
        del y_test
        del grid_search
        del param_grid
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
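The cost of the grid search in the cell above follows directly from the parameter grid. The sketch below copies `param_grid` and the number of folds from that cell and counts the configurations and model fits per subject. The helper `count_configurations` is hypothetical; with scikit-learn available, `len(ParameterGrid(param_grid))` should give the same count.

```python
from itertools import product

# Copied from the grid search cell above
param_grid = [
    {"CSP__n_components": [2, 3, 4, 6, 10],
     "LDA__solver": ["svd"],
     "LDA__tol": [0.0001, 0.00001, 0.001, 0.0004, 0.00007]},
    {"CSP__n_components": [2, 3, 4, 6, 10],
     "LDA__solver": ["lsqr", "eigen"]},
]
n_splits = 4  # StratifiedKFold(n_splits=4)


def count_configurations(grid):
    """Hypothetical helper: number of parameter combinations over all sub-grids."""
    return sum(len(list(product(*sub_grid.values()))) for sub_grid in grid)


n_configurations = count_configurations(param_grid)
n_fits = n_configurations * n_splits

print(f"Configurations per subject: {n_configurations}")  # 35
print(f"Model fits per subject: {n_fits}")                # 140
```

The first sub-grid contributes 25 configurations and the second 10, so each subject requires 140 pipeline fits before any manual refit.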
| Subject | CSP + LDA: cross validation accuracy | CSP + LDA: test split accuracy | CSP config | LDA config |
|---|---|---|---|---|
| B | 0.6615 +- 0.0504 | 0.6094 | 6 CSP components | LDA SVD solver with 0.0001 tol |
| C | 0.7144 +- 0.0341 | 0.7240 | 10 CSP components | LDA SVD solver with 0.0001 tol |
| E | 0.7343 +- 0.0171 | 0.7277 | 10 CSP components | LDA SVD solver with 0.0001 tol |
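The "mean +- standard deviation" notation in the cross validation column can be reproduced from the per-fold scores. The sketch below uses the four fold accuracies of subject B's best configuration as they appear in the grid search output of this notebook. The numbers match the population standard deviation (`ddof=0`) rather than the sample standard deviation.

```python
from statistics import fmean, pstdev

# Per-fold test accuracies of subject B's best configuration,
# copied from the grid search output (split0 to split3)
split_scores = [0.578125, 0.708333, 0.666667, 0.692708]

mean_score = fmean(split_scores)
std_score = pstdev(split_scores)  # population standard deviation

print(f"{mean_score:.4f} +- {std_score:.4f}")  # 0.6615 +- 0.0504
```

With only four folds, this spread is a rough indication of variability and not a confidence interval.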
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/gridsearch_csplda_subject{subject_id}.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.6615 +- 0.0504 with the following parameters
{'CSP__n_components': 6, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 17 | 2.638366 | 0.048087 | 0.002499 | 8.656529e-04 | 6 | svd | 0.001 | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.722222 | 0.684028 | 0.689236 | 0.696181 | 0.697917 | 0.014680 |
| 15 | 2.620203 | 0.050450 | 0.010247 | 9.281167e-03 | 6 | svd | 0.0001 | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.722222 | 0.684028 | 0.689236 | 0.696181 | 0.697917 | 0.014680 |
| 19 | 2.575015 | 0.037268 | 0.003250 | 4.326721e-04 | 6 | svd | 0.00007 | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.722222 | 0.684028 | 0.689236 | 0.696181 | 0.697917 | 0.014680 |
| 18 | 2.600368 | 0.093760 | 0.003749 | 2.486504e-03 | 6 | svd | 0.0004 | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.722222 | 0.684028 | 0.689236 | 0.696181 | 0.697917 | 0.014680 |
| 31 | 3.828332 | 0.061491 | 0.002750 | 8.287726e-04 | 6 | lsqr | NaN | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.720486 | 0.682292 | 0.689236 | 0.696181 | 0.697049 | 0.014395 |
| 32 | 2.727811 | 0.112638 | 0.003249 | 4.323963e-04 | 6 | eigen | NaN | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.720486 | 0.682292 | 0.689236 | 0.696181 | 0.697049 | 0.014395 |
| 16 | 2.634050 | 0.059822 | 0.002499 | 8.656874e-04 | 6 | svd | 0.00001 | 0.578125 | 0.708333 | 0.666667 | 0.692708 | 0.661458 | 0.050362 | 1 | 0.722222 | 0.684028 | 0.689236 | 0.696181 | 0.697917 | 0.014680 |
| 29 | 2.634230 | 0.069445 | 0.002499 | 8.656185e-04 | 4 | lsqr | NaN | 0.614583 | 0.692708 | 0.671875 | 0.661458 | 0.660156 | 0.028616 | 8 | 0.697917 | 0.666667 | 0.694444 | 0.684028 | 0.685764 | 0.012153 |
| 10 | 2.596897 | 0.081477 | 0.002000 | 1.685874e-07 | 4 | svd | 0.0001 | 0.614583 | 0.692708 | 0.671875 | 0.661458 | 0.660156 | 0.028616 | 8 | 0.699653 | 0.666667 | 0.694444 | 0.684028 | 0.686198 | 0.012602 |
| 11 | 2.542957 | 0.050572 | 0.001500 | 4.998446e-04 | 4 | svd | 0.00001 | 0.614583 | 0.692708 | 0.671875 | 0.661458 | 0.660156 | 0.028616 | 8 | 0.699653 | 0.666667 | 0.694444 | 0.684028 | 0.686198 | 0.012602 |
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 22 | 2.560733 | 0.046330 | 0.006748 | 0.004814 | 10 | svd | 0.001 | 0.598958 | 0.656250 | 0.671875 | 0.651042 | 0.644531 | 0.027406 | 24 | 0.729167 | 0.750000 | 0.715278 | 0.718750 | 0.728299 | 0.013532 |
| 23 | 2.698398 | 0.044207 | 0.003499 | 0.000500 | 10 | svd | 0.0004 | 0.598958 | 0.656250 | 0.671875 | 0.651042 | 0.644531 | 0.027406 | 24 | 0.729167 | 0.750000 | 0.715278 | 0.718750 | 0.728299 | 0.013532 |
| 24 | 2.619778 | 0.046420 | 0.005000 | 0.003464 | 10 | svd | 0.00007 | 0.598958 | 0.656250 | 0.671875 | 0.651042 | 0.644531 | 0.027406 | 24 | 0.729167 | 0.750000 | 0.715278 | 0.718750 | 0.728299 | 0.013532 |
| 1 | 2.690685 | 0.016161 | 0.002749 | 0.001920 | 2 | svd | 0.00001 | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
| 2 | 3.448836 | 0.068553 | 0.001750 | 0.000433 | 2 | svd | 0.001 | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
| 26 | 2.667278 | 0.062282 | 0.002250 | 0.000433 | 2 | eigen | NaN | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
| 4 | 2.847526 | 0.058999 | 0.001750 | 0.000432 | 2 | svd | 0.00007 | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
| 25 | 2.632682 | 0.062124 | 0.001750 | 0.000433 | 2 | lsqr | NaN | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
| 3 | 2.971258 | 0.054610 | 0.001750 | 0.000829 | 2 | svd | 0.0004 | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
| 0 | 2.860130 | 0.053891 | 0.002000 | 0.001224 | 2 | svd | 0.0001 | 0.609375 | 0.505208 | 0.526042 | 0.651042 | 0.572917 | 0.059612 | 29 | 0.664931 | 0.522569 | 0.541667 | 0.661458 | 0.597656 | 0.065897 |
In total there are 35 different configurations tested.
The best mean test score is 0.6615
There are 7 configurations with this maximum score
There are 28 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 28.000000 | 28.000000 | 28.000000 | 2.800000e+01 | 28.0 | 28 | 20.000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 | 28.000000 |
| unique | NaN | NaN | NaN | NaN | 4.0 | 3 | 5.000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 6.0 | svd | 0.001 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 20 | 4.000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.661364 | 0.060151 | 0.003457 | 1.805770e-03 | NaN | NaN | NaN | 0.595424 | 0.685268 | 0.664435 | 0.667969 | 0.653274 | 0.035462 | 12.071429 | 0.706225 | 0.691716 | 0.694134 | 0.690538 | 0.695654 | 0.011574 |
| std | 0.236873 | 0.020000 | 0.002313 | 2.365813e-03 | NaN | NaN | NaN | 0.013600 | 0.018775 | 0.011289 | 0.015633 | 0.007722 | 0.009210 | 8.519818 | 0.021296 | 0.034980 | 0.014254 | 0.020480 | 0.021547 | 0.003480 |
| min | 2.469499 | 0.022866 | 0.001500 | 1.685874e-07 | NaN | NaN | NaN | 0.578125 | 0.656250 | 0.645833 | 0.651042 | 0.644531 | 0.027406 | 1.000000 | 0.675347 | 0.666667 | 0.677083 | 0.663194 | 0.670573 | 0.005807 |
| 25% | 2.584624 | 0.046398 | 0.001937 | 4.996064e-04 | NaN | NaN | NaN | 0.585938 | 0.677083 | 0.661458 | 0.658854 | 0.645833 | 0.028616 | 6.250000 | 0.692274 | 0.666667 | 0.686198 | 0.678819 | 0.681966 | 0.010566 |
| 50% | 2.611223 | 0.058192 | 0.002999 | 8.656357e-04 | NaN | NaN | NaN | 0.596354 | 0.687500 | 0.669271 | 0.664062 | 0.653646 | 0.031372 | 11.500000 | 0.710069 | 0.674479 | 0.691840 | 0.690104 | 0.691623 | 0.012869 |
| 75% | 2.674828 | 0.068168 | 0.003811 | 2.166935e-03 | NaN | NaN | NaN | 0.602865 | 0.696615 | 0.671875 | 0.673177 | 0.660482 | 0.039228 | 19.000000 | 0.723524 | 0.700521 | 0.699653 | 0.701823 | 0.705512 | 0.013748 |
| max | 3.828332 | 0.112638 | 0.010247 | 9.281167e-03 | NaN | NaN | NaN | 0.614583 | 0.708333 | 0.677083 | 0.692708 | 0.661458 | 0.050362 | 24.000000 | 0.729167 | 0.750000 | 0.717014 | 0.718750 | 0.728299 | 0.014680 |
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.7144 +- 0.0341 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34 | 2.915794 | 0.092838 | 0.003749 | 0.000829 | 10 | eigen | NaN | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.753043 | 0.756944 | 0.750106 | 0.019726 |
| 24 | 2.820352 | 0.114471 | 0.010747 | 0.012271 | 10 | svd | 0.00007 | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.751304 | 0.756944 | 0.749671 | 0.019676 |
| 23 | 2.757913 | 0.079570 | 0.013246 | 0.017748 | 10 | svd | 0.0004 | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.751304 | 0.756944 | 0.749671 | 0.019676 |
| 22 | 2.836415 | 0.134609 | 0.004249 | 0.000829 | 10 | svd | 0.001 | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.751304 | 0.756944 | 0.749671 | 0.019676 |
| 21 | 2.743798 | 0.063981 | 0.004749 | 0.001920 | 10 | svd | 0.00001 | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.751304 | 0.756944 | 0.749671 | 0.019676 |
| 20 | 2.835930 | 0.074028 | 0.004249 | 0.000433 | 10 | svd | 0.0001 | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.751304 | 0.756944 | 0.749671 | 0.019676 |
| 33 | 2.963102 | 0.101336 | 0.003749 | 0.000828 | 10 | lsqr | NaN | 0.755208 | 0.671875 | 0.739583 | 0.691099 | 0.714442 | 0.034098 | 1 | 0.772174 | 0.718261 | 0.753043 | 0.756944 | 0.750106 | 0.019726 |
| 19 | 2.729544 | 0.103255 | 0.002250 | 0.000433 | 6 | svd | 0.00007 | 0.744792 | 0.656250 | 0.703125 | 0.732984 | 0.709288 | 0.034180 | 8 | 0.766957 | 0.711304 | 0.749565 | 0.744792 | 0.743154 | 0.020154 |
| 18 | 2.758624 | 0.100434 | 0.002749 | 0.000829 | 6 | svd | 0.0004 | 0.744792 | 0.656250 | 0.703125 | 0.732984 | 0.709288 | 0.034180 | 8 | 0.766957 | 0.711304 | 0.749565 | 0.744792 | 0.743154 | 0.020154 |
| 16 | 2.737050 | 0.137008 | 0.002250 | 0.000432 | 6 | svd | 0.00001 | 0.744792 | 0.656250 | 0.703125 | 0.732984 | 0.709288 | 0.034180 | 8 | 0.766957 | 0.711304 | 0.749565 | 0.744792 | 0.743154 | 0.020154 |
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 2.657662 | 0.055349 | 0.001500 | 0.000499 | 3 | svd | 0.0001 | 0.364583 | 0.500000 | 0.390625 | 0.397906 | 0.413279 | 0.051578 | 22 | 0.384348 | 0.457391 | 0.420870 | 0.409722 | 0.418083 | 0.026272 |
| 28 | 2.869725 | 0.119720 | 0.001499 | 0.000500 | 3 | eigen | NaN | 0.364583 | 0.500000 | 0.390625 | 0.397906 | 0.413279 | 0.051578 | 22 | 0.382609 | 0.457391 | 0.420870 | 0.409722 | 0.417648 | 0.026835 |
| 27 | 2.874114 | 0.081583 | 0.001749 | 0.000433 | 3 | lsqr | NaN | 0.364583 | 0.500000 | 0.390625 | 0.397906 | 0.413279 | 0.051578 | 22 | 0.382609 | 0.457391 | 0.420870 | 0.409722 | 0.417648 | 0.026835 |
| 2 | 2.873989 | 0.139466 | 0.001749 | 0.000433 | 2 | svd | 0.001 | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.353043 | 0.373913 | 0.424348 | 0.401042 | 0.388087 | 0.026980 |
| 25 | 2.965988 | 0.100436 | 0.001750 | 0.000432 | 2 | lsqr | NaN | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.351304 | 0.373913 | 0.424348 | 0.399306 | 0.387218 | 0.027348 |
| 1 | 3.268848 | 0.095755 | 0.001999 | 0.000707 | 2 | svd | 0.00001 | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.353043 | 0.373913 | 0.424348 | 0.401042 | 0.388087 | 0.026980 |
| 3 | 2.706552 | 0.076533 | 0.004249 | 0.004491 | 2 | svd | 0.0004 | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.353043 | 0.373913 | 0.424348 | 0.401042 | 0.388087 | 0.026980 |
| 4 | 2.616163 | 0.072099 | 0.001750 | 0.000432 | 2 | svd | 0.00007 | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.353043 | 0.373913 | 0.424348 | 0.401042 | 0.388087 | 0.026980 |
| 26 | 2.815629 | 0.050155 | 0.002249 | 0.000432 | 2 | eigen | NaN | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.351304 | 0.373913 | 0.424348 | 0.399306 | 0.387218 | 0.027348 |
| 0 | 3.246334 | 0.148674 | 0.001500 | 0.000499 | 2 | svd | 0.0001 | 0.364583 | 0.416667 | 0.385417 | 0.387435 | 0.388525 | 0.018547 | 29 | 0.353043 | 0.373913 | 0.424348 | 0.401042 | 0.388087 | 0.026980 |
In total there are 35 different configurations tested.
The best mean test score is 0.7144
There are 7 configurations with this maximum score
There are 14 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 14.000000 | 14.000000 | 14.000000 | 1.400000e+01 | 14.0 | 14 | 10.00000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 3 | 5.00000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 10 | 2.00000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.823481 | 0.099315 | 0.004602 | 2.855871e-03 | NaN | NaN | NaN | 0.750000 | 0.664062 | 0.720610 | 0.712042 | 0.711679 | 0.034183 | 5.214286 | 0.769565 | 0.714783 | 0.750683 | 0.750868 | 0.746475 | 0.019922 |
| std | 0.101448 | 0.025941 | 0.003257 | 5.290822e-03 | NaN | NaN | NaN | 0.005405 | 0.008107 | 0.019765 | 0.021733 | 0.002900 | 0.000135 | 4.676925 | 0.002707 | 0.003610 | 0.001296 | 0.006306 | 0.003449 | 0.000241 |
| min | 2.729544 | 0.051201 | 0.002250 | 6.529362e-07 | NaN | NaN | NaN | 0.744792 | 0.656250 | 0.697917 | 0.691099 | 0.707986 | 0.034098 | 1.000000 | 0.766957 | 0.711304 | 0.749565 | 0.744792 | 0.743154 | 0.019676 |
| 25% | 2.747327 | 0.082550 | 0.002907 | 4.614444e-04 | NaN | NaN | NaN | 0.744792 | 0.656250 | 0.703125 | 0.691099 | 0.709288 | 0.034098 | 1.000000 | 0.766957 | 0.711304 | 0.749565 | 0.744792 | 0.743154 | 0.019676 |
| 50% | 2.798684 | 0.100885 | 0.003658 | 8.283594e-04 | NaN | NaN | NaN | 0.750000 | 0.664062 | 0.721354 | 0.712042 | 0.711865 | 0.034139 | 4.500000 | 0.769565 | 0.714783 | 0.750435 | 0.750868 | 0.746413 | 0.019940 |
| 75% | 2.846303 | 0.119871 | 0.004249 | 1.647068e-03 | NaN | NaN | NaN | 0.755208 | 0.671875 | 0.739583 | 0.732984 | 0.714442 | 0.034180 | 8.000000 | 0.772174 | 0.718261 | 0.751304 | 0.756944 | 0.749671 | 0.020154 |
| max | 3.073829 | 0.137008 | 0.013246 | 1.774776e-02 | NaN | NaN | NaN | 0.755208 | 0.671875 | 0.739583 | 0.732984 | 0.714442 | 0.034488 | 13.000000 | 0.772174 | 0.718261 | 0.753043 | 0.756944 | 0.750106 | 0.020154 |
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.7343 +- 0.0171 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34 | 2.668360 | 0.090464 | 0.003749 | 8.288804e-04 | 10 | eigen | NaN | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.787086 | 0.783595 | 0.736475 | 0.790576 | 0.774433 | 0.022054 |
| 24 | 2.820031 | 0.079210 | 0.005749 | 2.585159e-03 | 10 | svd | 0.00007 | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.785340 | 0.783595 | 0.736475 | 0.790576 | 0.773997 | 0.021815 |
| 23 | 2.924817 | 0.073776 | 0.003749 | 4.327406e-04 | 10 | svd | 0.0004 | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.785340 | 0.783595 | 0.736475 | 0.790576 | 0.773997 | 0.021815 |
| 22 | 2.975696 | 0.111334 | 0.003000 | 2.384186e-07 | 10 | svd | 0.001 | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.785340 | 0.783595 | 0.736475 | 0.790576 | 0.773997 | 0.021815 |
| 21 | 3.052873 | 0.155479 | 0.003500 | 8.659283e-04 | 10 | svd | 0.00001 | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.785340 | 0.783595 | 0.736475 | 0.790576 | 0.773997 | 0.021815 |
| 20 | 2.979868 | 0.068541 | 0.007247 | 6.829423e-03 | 10 | svd | 0.0001 | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.785340 | 0.783595 | 0.736475 | 0.790576 | 0.773997 | 0.021815 |
| 33 | 2.914200 | 0.060271 | 0.007511 | 7.238833e-03 | 10 | lsqr | NaN | 0.706806 | 0.738220 | 0.738220 | 0.753927 | 0.734293 | 0.017116 | 1 | 0.787086 | 0.783595 | 0.736475 | 0.790576 | 0.774433 | 0.022054 |
| 32 | 2.955932 | 0.153432 | 0.002499 | 4.995468e-04 | 6 | eigen | NaN | 0.691099 | 0.722513 | 0.722513 | 0.664921 | 0.700262 | 0.024099 | 8 | 0.731239 | 0.706806 | 0.734729 | 0.741710 | 0.728621 | 0.013147 |
| 31 | 2.883022 | 0.129580 | 0.003499 | 5.000235e-04 | 6 | lsqr | NaN | 0.691099 | 0.722513 | 0.722513 | 0.664921 | 0.700262 | 0.024099 | 8 | 0.731239 | 0.706806 | 0.734729 | 0.741710 | 0.728621 | 0.013147 |
| 19 | 2.939181 | 0.091919 | 0.007998 | 9.243115e-03 | 6 | svd | 0.00007 | 0.691099 | 0.722513 | 0.722513 | 0.664921 | 0.700262 | 0.024099 | 8 | 0.731239 | 0.706806 | 0.736475 | 0.741710 | 0.729058 | 0.013370 |
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 3.101356 | 0.078158 | 0.001251 | 0.000432 | 3 | svd | 0.0001 | 0.659686 | 0.612565 | 0.586387 | 0.649215 | 0.626963 | 0.029239 | 22 | 0.670157 | 0.626527 | 0.638743 | 0.687609 | 0.655759 | 0.024320 |
| 27 | 2.789545 | 0.113972 | 0.001999 | 0.000707 | 3 | lsqr | NaN | 0.659686 | 0.612565 | 0.586387 | 0.649215 | 0.626963 | 0.029239 | 22 | 0.670157 | 0.626527 | 0.638743 | 0.687609 | 0.655759 | 0.024320 |
| 28 | 2.864160 | 0.086134 | 0.001749 | 0.000829 | 3 | eigen | NaN | 0.659686 | 0.612565 | 0.586387 | 0.649215 | 0.626963 | 0.029239 | 22 | 0.670157 | 0.626527 | 0.638743 | 0.687609 | 0.655759 | 0.024320 |
| 1 | 3.146745 | 0.082234 | 0.002250 | 0.000433 | 2 | svd | 0.00001 | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.596859 | 0.575916 | 0.635253 | 0.595113 | 0.024987 |
| 2 | 3.118679 | 0.088869 | 0.001500 | 0.000866 | 2 | svd | 0.001 | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.596859 | 0.575916 | 0.635253 | 0.595113 | 0.024987 |
| 3 | 3.020644 | 0.161201 | 0.001499 | 0.000500 | 2 | svd | 0.0004 | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.596859 | 0.575916 | 0.635253 | 0.595113 | 0.024987 |
| 25 | 2.802779 | 0.112232 | 0.001250 | 0.000433 | 2 | lsqr | NaN | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.598604 | 0.575916 | 0.635253 | 0.595550 | 0.025029 |
| 4 | 2.984879 | 0.066979 | 0.001750 | 0.000432 | 2 | svd | 0.00007 | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.596859 | 0.575916 | 0.635253 | 0.595113 | 0.024987 |
| 26 | 2.794722 | 0.149503 | 0.001999 | 0.000707 | 2 | eigen | NaN | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.598604 | 0.575916 | 0.635253 | 0.595550 | 0.025029 |
| 0 | 3.167731 | 0.150488 | 0.001999 | 0.000707 | 2 | svd | 0.0001 | 0.638743 | 0.534031 | 0.544503 | 0.602094 | 0.579843 | 0.042755 | 29 | 0.572426 | 0.596859 | 0.575916 | 0.635253 | 0.595113 | 0.024987 |
In total there are 35 different configurations tested.
The best mean test score is 0.7343
There are 7 configurations with this maximum score
There are 7 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 7.000000 | 7.000000 | 7.000000 | 7.000000e+00 | 7.0 | 7 | 5.00000 | 7.000000 | 7.000000e+00 | 7.000000e+00 | 7.000000 | 7.000000e+00 | 7.000000e+00 | 7.0 | 7.000000 | 7.000000e+00 | 7.000000e+00 | 7.000000e+00 | 7.000000 | 7.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 3 | 5.00000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 5 | 1.00000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.905121 | 0.091296 | 0.004929 | 2.683029e-03 | NaN | NaN | NaN | 0.706806 | 7.382199e-01 | 7.382199e-01 | 0.753927 | 7.342932e-01 | 1.711610e-02 | 1.0 | 0.785839 | 7.835951e-01 | 7.364747e-01 | 7.905759e-01 | 0.774121 | 0.021883 |
| std | 0.126600 | 0.032800 | 0.001882 | 3.081082e-03 | NaN | NaN | NaN | 0.000000 | 1.199178e-16 | 1.199178e-16 | 0.000000 | 1.199178e-16 | 3.747431e-18 | 0.0 | 0.000852 | 1.199178e-16 | 1.199178e-16 | 1.199178e-16 | 0.000213 | 0.000116 |
| min | 2.668360 | 0.060271 | 0.003000 | 2.384186e-07 | NaN | NaN | NaN | 0.706806 | 7.382199e-01 | 7.382199e-01 | 0.753927 | 7.342932e-01 | 1.711610e-02 | 1.0 | 0.785340 | 7.835951e-01 | 7.364747e-01 | 7.905759e-01 | 0.773997 | 0.021815 |
| 25% | 2.867116 | 0.071159 | 0.003625 | 6.308105e-04 | NaN | NaN | NaN | 0.706806 | 7.382199e-01 | 7.382199e-01 | 0.753927 | 7.342932e-01 | 1.711610e-02 | 1.0 | 0.785340 | 7.835951e-01 | 7.364747e-01 | 7.905759e-01 | 0.773997 | 0.021815 |
| 50% | 2.924817 | 0.079210 | 0.003749 | 8.659283e-04 | NaN | NaN | NaN | 0.706806 | 7.382199e-01 | 7.382199e-01 | 0.753927 | 7.342932e-01 | 1.711610e-02 | 1.0 | 0.785340 | 7.835951e-01 | 7.364747e-01 | 7.905759e-01 | 0.773997 | 0.021815 |
| 75% | 2.977782 | 0.100899 | 0.006498 | 4.707291e-03 | NaN | NaN | NaN | 0.706806 | 7.382199e-01 | 7.382199e-01 | 0.753927 | 7.342932e-01 | 1.711610e-02 | 1.0 | 0.786213 | 7.835951e-01 | 7.364747e-01 | 7.905759e-01 | 0.774215 | 0.021934 |
| max | 3.052873 | 0.155479 | 0.007511 | 7.238833e-03 | NaN | NaN | NaN | 0.706806 | 7.382199e-01 | 7.382199e-01 | 0.753927 | 7.342932e-01 | 1.711610e-02 | 1.0 | 0.787086 | 7.835951e-01 | 7.364747e-01 | 7.905759e-01 | 0.774433 | 0.022054 |
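The `mean_test_score` and `std_test_score` columns can be checked by hand from the four per-split scores. The sketch below does this for the rank 1 rows of the table above, assuming that the reported standard deviation is the population standard deviation (no Bessel correction); the variable names are illustrative and not taken from the notebook.

```python
from statistics import fmean, pstdev

# Per-split test accuracies copied from the rank 1 rows above
# (split0_test_score ... split3_test_score)
split_test_scores = [0.706806, 0.738220, 0.738220, 0.753927]

# Mean accuracy over the four cross-validation folds
mean_test_score = fmean(split_test_scores)

# Population standard deviation (ddof=0); this is an assumption about
# how the grid search aggregates the fold scores
std_test_score = pstdev(split_test_scores)

print(f"Cross-validation accuracy: {mean_test_score:.4f} +- {std_test_score:.4f}")
```

Under that assumption the two values agree with the `mean_test_score` and `std_test_score` entries of those rows to the displayed precision.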
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
best_found_csp_components = [4, 10, 10]
best_found_solver = ["svd", "svd", "svd"]
best_found_tol = [0.0001, 0.0001, 0.0001]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    # Open train and test data from file
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/testdata-x_csplda_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        X_test = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/testdata-y_csplda_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        y_test = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/traindata-x_csplda_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        X_train = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/traindata-y_csplda_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        y_train = pickle.load(f)
    # Make the classifier with the best parameters found for this subject
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est="epoch",
              n_components=best_found_csp_components[i])
    lda = LinearDiscriminantAnalysis(shrinkage=None,
                                     priors=[1/3, 1/3, 1/3],
                                     solver=best_found_solver[i],
                                     tol=best_found_tol[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('LDA', lda)])
    # Fit the pipeline on the training split, suppressing the fit output
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit on the held-out test split
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true=y_test, y_pred=y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id=subject_ids_to_test[i]).info,
                                  ch_type='eeg',
                                  units='Patterns (AU)',
                                  size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_solver
del best_found_tol
del i
del f
del X_test
del y_test
del X_train
del y_train
del csp
del lda
del pipeline
del y_pred
del accuracy
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.6094
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.724
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.7277
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
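This experiment works as follows:
- For each subject with three recordings (B, C and E), the most recent recording is loaded.
- Epochs are extracted with a start offset of -1 second and an end offset of 1 second around the visual cue, baseline-corrected on the data before the cue, and restricted to the neutral, left and right MI tasks.
- The epochs are band-pass filtered between 2 Hz and 32 Hz, after which the half-second window from 0.1 to 0.6 seconds is kept.
- The data is split into a stratified 80% train and 20% test set.
- A CSP + SVM pipeline is grid searched on the training set using stratified 4-fold cross validation.
- The grid search results and the train and test data are stored so the best model can be retrained later.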
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        # Get MNE raw object for latest recording of that subject
        mne_raw = CLA_dataset.get_last_raw_mne_data_for_subject(subject_id=subject_id)
        # Get epochs for that MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset=start_offset,
                                                             end_offset=end_offset,
                                                             baseline=baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq=filter_lower_bound,
                          h_freq=filter_upper_bound,
                          picks="all",
                          phase="minimum",
                          fir_window="blackman",
                          fir_design="firwin",
                          pad='median',
                          n_jobs=-1,
                          verbose=False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin=0.1, tmax=0.6)
        # Create a test and train split
        X_train, X_test, y_train, y_test = train_test_split(mne_epochs_data,
                                                            labels,
                                                            test_size=0.2,
                                                            shuffle=True,
                                                            stratify=labels,
                                                            random_state=1998)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est="epoch")
        svm = SVC()
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('SVM', svm)])
        # Configure cross validation to use
        cv = StratifiedKFold(n_splits=4,
                             shuffle=True,
                             random_state=2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        param_grid = [{
            "CSP__n_components": [4, 6, 10],
            "SVM__C": [0.01, 0.1, 1, 10, 100],
            "SVM__kernel": ['rbf', 'sigmoid'],
            "SVM__gamma": ['scale', 'auto', 10, 1, 0.1, 0.01, 0.001]
        }, {
            "CSP__n_components": [4, 6, 10],
            "SVM__C": [0.01, 0.1, 1, 10, 100],
            "SVM__kernel": ['linear']
        }]
        # Configure the grid search
        grid_search = GridSearchCV(estimator=pipeline,
                                   param_grid=param_grid,
                                   scoring="accuracy",
                                   n_jobs=-1,
                                   refit=False, # We will do this manually
                                   cv=cv,
                                   verbose=10,
                                   return_train_score=True)
        # Do the grid search on the training data
        grid_search.fit(X=X_train, y=y_train)
        # Store the results of the grid search
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/gridsearch_cspsvm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Store the train and test data so the best model can be retrained later
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/testdata-x_cspsvm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(X_test, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/testdata-y_cspsvm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(y_test, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/traindata-x_cspsvm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(X_train, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/traindata-y_cspsvm_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(y_train, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del svm
        del pipeline
        del labels
        del cv
        del file
        del X_train
        del X_test
        del y_train
        del y_test
        del grid_search
        del param_grid
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
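To get a feel for the cost of this search, the number of candidate configurations can be counted before running it. The following is a minimal sketch using only the standard library: it copies the two grid dictionaries from the cell above and expands each into its combinations. The helper `expand` is hypothetical and written for this illustration; scikit-learn's `ParameterGrid` offers equivalent functionality.

```python
from itertools import product

# Same two sub-grids as in the grid search cell above
param_grid = [
    {
        "CSP__n_components": [4, 6, 10],
        "SVM__C": [0.01, 0.1, 1, 10, 100],
        "SVM__kernel": ["rbf", "sigmoid"],
        "SVM__gamma": ["scale", "auto", 10, 1, 0.1, 0.01, 0.001],
    },
    {
        "CSP__n_components": [4, 6, 10],
        "SVM__C": [0.01, 0.1, 1, 10, 100],
        "SVM__kernel": ["linear"],
    },
]


def expand(grid):
    """Yield one parameter dict per combination in a single sub-grid (hypothetical helper)."""
    keys = sorted(grid)
    for values in product(*(grid[key] for key in keys)):
        yield dict(zip(keys, values))


# Sub-grids are searched independently, so their candidates simply add up
candidates = [candidate for grid in param_grid for candidate in expand(grid)]

n_splits = 4  # Folds of the StratifiedKFold used above
n_fits = len(candidates) * n_splits

print(f"{len(candidates)} candidate configurations, {n_fits} pipeline fits per subject")
```

Keeping the linear kernel in its own sub-grid avoids repeating it for every `gamma` value, which the linear kernel does not use.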
| Subject | CSP + SVM: cross validation accuracy | CSP + SVM: test split accuracy | CSP config | SVM config |
|---|---|---|---|---|
| B | 0.6693 +- 0.0298 | 0.6146 | 4 CSP components | SVM RBF with C 0.1 and Gamma auto |
| C | 0.7262 +- 0.0298 | 0.7448 | 6 CSP components | SVM RBF with C 100 and Gamma 0.001 |
| E | 0.7356 +- 0.0159 | 0.7016 | 10 CSP components | SVM sigmoid with C 100 and Gamma 0.01 |
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/gridsearch_cspsvm_subject{subject_id}.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace=True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    # Configurations that share the exact maximum score
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    # Configurations at most 0.02 below the maximum score
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
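The two counts printed by the cell above (configurations sharing the maximum score, and configurations within 0.02 of it) boil down to inclusive range checks on `mean_test_score`. Below is a small standard-library sketch of the same logic, using a handful of mean test scores taken from the subject B tables as illustrative input; it is not the notebook's pandas code and the variable names are assumptions.

```python
# A few mean_test_score values from the subject B results (illustrative subset)
mean_test_scores = [0.669271, 0.667969, 0.667969, 0.666667, 0.665365, 0.404948, 0.337240]

max_score = max(mean_test_scores)
tolerance = 0.02  # Same margin as used in the notebook

# Equivalent of between(max_score, max_score): exact ties with the best score
shared_first_place_count = sum(1 for score in mean_test_scores if score == max_score)

# Equivalent of between(max_score - 0.02, max_score): both bounds inclusive
close_first_place_count = sum(
    1 for score in mean_test_scores if max_score - tolerance <= score <= max_score
)

print(f"Best mean test score: {round(max_score, 4)}")
print(f"Configurations with this maximum score: {shared_first_place_count}")
print(f"Configurations within {tolerance} of this maximum score: {close_first_place_count}")
```

A large near-best group relative to the number of exact ties suggests that the choice among the top configurations is not very sensitive to the exact hyperparameters.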
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.6693 +- 0.0298 with the following parameters
{'CSP__n_components': 4, 'SVM__C': 0.1, 'SVM__gamma': 'auto', 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 16 | 2.062838 | 0.029235 | 0.010996 | 4.878850e-07 | 4 | 0.1 | auto | rbf | 0.630208 | 0.713542 | ... | 0.661458 | 0.669271 | 0.029806 | 1 | 0.696181 | 0.673611 | 0.677083 | 0.699653 | 0.686632 | 0.011417 |
| 52 | 2.095329 | 0.047012 | 0.010246 | 8.289531e-04 | 4 | 10 | 0.01 | rbf | 0.635417 | 0.692708 | ... | 0.651042 | 0.667969 | 0.025349 | 2 | 0.697917 | 0.668403 | 0.689236 | 0.694444 | 0.687500 | 0.011450 |
| 50 | 2.056091 | 0.033294 | 0.008747 | 4.331534e-04 | 4 | 10 | 0.1 | rbf | 0.630208 | 0.682292 | ... | 0.677083 | 0.667969 | 0.021904 | 2 | 0.713542 | 0.694444 | 0.689236 | 0.701389 | 0.699653 | 0.009104 |
| 212 | 2.063338 | 0.040255 | 0.002749 | 4.330503e-04 | 4 | 1 | NaN | linear | 0.625000 | 0.697917 | ... | 0.661458 | 0.666667 | 0.027313 | 4 | 0.703125 | 0.668403 | 0.682292 | 0.689236 | 0.685764 | 0.012519 |
| 36 | 2.077584 | 0.031576 | 0.009497 | 4.999638e-04 | 4 | 1 | 0.1 | rbf | 0.614583 | 0.687500 | ... | 0.666667 | 0.665365 | 0.030895 | 5 | 0.704861 | 0.678819 | 0.682292 | 0.697917 | 0.690972 | 0.010772 |
| 53 | 2.071336 | 0.031652 | 0.003748 | 4.328437e-04 | 4 | 10 | 0.01 | sigmoid | 0.630208 | 0.697917 | ... | 0.645833 | 0.665365 | 0.028138 | 5 | 0.704861 | 0.668403 | 0.678819 | 0.699653 | 0.687934 | 0.014903 |
| 69 | 2.080583 | 0.027161 | 0.004249 | 4.327411e-04 | 4 | 100 | 0.001 | sigmoid | 0.630208 | 0.697917 | ... | 0.645833 | 0.665365 | 0.028138 | 5 | 0.704861 | 0.668403 | 0.678819 | 0.699653 | 0.687934 | 0.014903 |
| 68 | 2.096828 | 0.064931 | 0.009498 | 8.665823e-04 | 4 | 100 | 0.001 | rbf | 0.630208 | 0.692708 | ... | 0.651042 | 0.665365 | 0.025878 | 5 | 0.694444 | 0.663194 | 0.680556 | 0.696181 | 0.683594 | 0.013243 |
| 67 | 2.056341 | 0.028035 | 0.003748 | 4.323968e-04 | 4 | 100 | 0.01 | sigmoid | 0.625000 | 0.692708 | ... | 0.661458 | 0.665365 | 0.025878 | 5 | 0.703125 | 0.668403 | 0.684028 | 0.687500 | 0.685764 | 0.012337 |
| 66 | 2.057090 | 0.030234 | 0.009748 | 1.298875e-03 | 4 | 100 | 0.01 | rbf | 0.625000 | 0.692708 | ... | 0.666667 | 0.665365 | 0.025080 | 5 | 0.704861 | 0.678819 | 0.668403 | 0.689236 | 0.685330 | 0.013469 |
10 rows × 21 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 89 | 2.061839 | 0.017805 | 0.007498 | 5.000853e-04 | 6 | 0.1 | 10 | sigmoid | 0.380208 | 0.416667 | ... | 0.406250 | 0.404948 | 0.014903 | 216 | 0.442708 | 0.390625 | 0.388889 | 0.369792 | 0.398003 | 0.027074 |
| 33 | 2.057590 | 0.031457 | 0.005248 | 4.327061e-04 | 4 | 1 | 10 | sigmoid | 0.338542 | 0.520833 | ... | 0.338542 | 0.399740 | 0.074424 | 217 | 0.394097 | 0.442708 | 0.342014 | 0.322917 | 0.375434 | 0.046768 |
| 47 | 2.070586 | 0.019985 | 0.005498 | 4.995478e-04 | 4 | 10 | 10 | sigmoid | 0.364583 | 0.510417 | ... | 0.307292 | 0.398438 | 0.074435 | 218 | 0.397569 | 0.444444 | 0.362847 | 0.314236 | 0.379774 | 0.047646 |
| 61 | 2.058839 | 0.027923 | 0.005249 | 4.335321e-04 | 4 | 100 | 10 | sigmoid | 0.359375 | 0.510417 | ... | 0.302083 | 0.393229 | 0.076236 | 219 | 0.397569 | 0.446181 | 0.348958 | 0.324653 | 0.379340 | 0.046674 |
| 74 | 2.081832 | 0.023893 | 0.012996 | 2.920019e-07 | 6 | 0.01 | 10 | rbf | 0.453125 | 0.401042 | ... | 0.338542 | 0.381510 | 0.049187 | 220 | 0.663194 | 0.664931 | 0.338542 | 0.336806 | 0.500868 | 0.163197 |
| 88 | 2.069335 | 0.032868 | 0.013995 | 1.731340e-03 | 6 | 0.1 | 10 | rbf | 0.453125 | 0.401042 | ... | 0.338542 | 0.381510 | 0.049187 | 220 | 0.663194 | 0.664931 | 0.338542 | 0.336806 | 0.500868 | 0.163197 |
| 19 | 2.069086 | 0.035467 | 0.006498 | 4.997254e-04 | 4 | 0.1 | 10 | sigmoid | 0.359375 | 0.380208 | ... | 0.369792 | 0.380208 | 0.019488 | 222 | 0.385417 | 0.388889 | 0.413194 | 0.342014 | 0.382378 | 0.025644 |
| 172 | 2.081583 | 0.039720 | 0.015496 | 8.666903e-04 | 10 | 1 | 10 | rbf | 0.421875 | 0.390625 | ... | 0.333333 | 0.376302 | 0.033222 | 223 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 144 | 2.066337 | 0.038231 | 0.015496 | 8.663757e-04 | 10 | 0.01 | 10 | rbf | 0.333333 | 0.343750 | ... | 0.338542 | 0.337240 | 0.004319 | 224 | 0.673611 | 0.673611 | 0.338542 | 0.336806 | 0.505642 | 0.167970 |
| 158 | 2.077584 | 0.039322 | 0.014996 | 7.078143e-04 | 10 | 0.1 | 10 | rbf | 0.333333 | 0.343750 | ... | 0.338542 | 0.337240 | 0.004319 | 224 | 0.673611 | 0.673611 | 0.338542 | 0.336806 | 0.505642 | 0.167970 |
10 rows × 21 columns
In total there are 225 different configurations tested.
The best mean test score is 0.6693
There are 1 configurations with this maximum score
There are 41 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 41.000000 | 41.000000 | 41.000000 | 4.100000e+01 | 41.0 | 41.0 | 35.00 | 41 | 41.000000 | 41.000000 | ... | 41.000000 | 41.000000 | 41.000000 | 41.000000 | 41.000000 | 41.000000 | 41.000000 | 41.000000 | 41.000000 | 41.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 5.0 | 6.00 | 3 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 4.0 | 0.1 | 0.01 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 30.0 | 12.0 | 9.00 | 23 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.066697 | 0.033610 | 0.007784 | 6.481829e-04 | NaN | NaN | NaN | NaN | 0.616489 | 0.696519 | ... | 0.657393 | 0.659966 | 0.030955 | 19.756098 | 0.709646 | 0.681741 | 0.684282 | 0.697282 | 0.693238 | 0.012254 |
| std | 0.012887 | 0.008334 | 0.003400 | 4.352666e-04 | NaN | NaN | NaN | NaN | 0.014720 | 0.013022 | ... | 0.015341 | 0.005172 | 0.007044 | 12.287759 | 0.013757 | 0.013606 | 0.014161 | 0.010054 | 0.011620 | 0.002768 |
| min | 2.049093 | 0.017749 | 0.002749 | 3.097148e-07 | NaN | NaN | NaN | NaN | 0.588542 | 0.677083 | ... | 0.630208 | 0.649740 | 0.017806 | 1.000000 | 0.673611 | 0.661458 | 0.659722 | 0.663194 | 0.664931 | 0.004490 |
| 25% | 2.056590 | 0.028129 | 0.004250 | 4.328437e-04 | NaN | NaN | NaN | NaN | 0.604167 | 0.687500 | ... | 0.645833 | 0.656250 | 0.025878 | 5.000000 | 0.699653 | 0.671875 | 0.677083 | 0.689236 | 0.685764 | 0.011142 |
| 50% | 2.062839 | 0.032509 | 0.008997 | 4.341517e-04 | NaN | NaN | NaN | NaN | 0.619792 | 0.692708 | ... | 0.656250 | 0.660156 | 0.029232 | 21.000000 | 0.704861 | 0.678819 | 0.682292 | 0.697917 | 0.690538 | 0.012905 |
| 75% | 2.073335 | 0.038211 | 0.010497 | 8.289704e-04 | NaN | NaN | NaN | NaN | 0.630208 | 0.697917 | ... | 0.671875 | 0.665365 | 0.036644 | 27.000000 | 0.723958 | 0.687500 | 0.690972 | 0.703125 | 0.700087 | 0.014178 |
| max | 2.096828 | 0.064931 | 0.013246 | 1.920093e-03 | NaN | NaN | NaN | NaN | 0.640625 | 0.729167 | ... | 0.687500 | 0.669271 | 0.041991 | 40.000000 | 0.732639 | 0.718750 | 0.725694 | 0.722222 | 0.724392 | 0.016804 |
11 rows × 21 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.7262 +- 0.0298 with the following parameters
{'CSP__n_components': 6, 'SVM__C': 100, 'SVM__gamma': 0.001, 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 138 | 2.138314 | 0.064616 | 0.009747 | 4.325340e-04 | 6 | 100 | 0.001 | rbf | 0.755208 | 0.677083 | ... | 0.743455 | 0.726228 | 0.029835 | 1 | 0.760000 | 0.725217 | 0.760000 | 0.756944 | 0.750540 | 0.014673 |
| 122 | 2.131567 | 0.064912 | 0.009497 | 4.990702e-04 | 6 | 10 | 0.01 | rbf | 0.729167 | 0.677083 | ... | 0.748691 | 0.721027 | 0.026593 | 2 | 0.760000 | 0.732174 | 0.766957 | 0.756944 | 0.754019 | 0.013124 |
| 123 | 2.122069 | 0.059648 | 0.004749 | 4.328783e-04 | 6 | 10 | 0.01 | sigmoid | 0.760417 | 0.677083 | ... | 0.738220 | 0.721013 | 0.031382 | 3 | 0.756522 | 0.716522 | 0.760000 | 0.758681 | 0.747931 | 0.018177 |
| 216 | 2.111574 | 0.054261 | 0.003998 | 5.196212e-07 | 6 | 0.1 | NaN | linear | 0.755208 | 0.671875 | ... | 0.743455 | 0.719718 | 0.032564 | 4 | 0.754783 | 0.718261 | 0.760000 | 0.758681 | 0.747931 | 0.017237 |
| 139 | 2.113572 | 0.055252 | 0.004498 | 4.998447e-04 | 6 | 100 | 0.001 | sigmoid | 0.755208 | 0.671875 | ... | 0.743455 | 0.719718 | 0.032564 | 4 | 0.753043 | 0.718261 | 0.760000 | 0.758681 | 0.747496 | 0.017080 |
| 98 | 2.128318 | 0.054933 | 0.009747 | 4.328781e-04 | 6 | 1 | scale | rbf | 0.739583 | 0.677083 | ... | 0.743455 | 0.717114 | 0.026825 | 6 | 0.765217 | 0.739130 | 0.761739 | 0.751736 | 0.754456 | 0.010138 |
| 217 | 2.111574 | 0.063782 | 0.004248 | 8.283779e-04 | 6 | 1 | NaN | linear | 0.760417 | 0.661458 | ... | 0.738220 | 0.717107 | 0.037065 | 7 | 0.758261 | 0.725217 | 0.763478 | 0.753472 | 0.750107 | 0.014799 |
| 219 | 2.198046 | 0.064861 | 0.003749 | 4.330502e-04 | 6 | 100 | NaN | linear | 0.760417 | 0.656250 | ... | 0.738220 | 0.717107 | 0.038852 | 8 | 0.761739 | 0.723478 | 0.763478 | 0.750000 | 0.749674 | 0.015988 |
| 218 | 2.132316 | 0.058110 | 0.004499 | 8.661691e-04 | 6 | 10 | NaN | linear | 0.765625 | 0.651042 | ... | 0.738220 | 0.717107 | 0.042359 | 8 | 0.763478 | 0.725217 | 0.763478 | 0.750000 | 0.750543 | 0.015623 |
| 106 | 2.116072 | 0.060513 | 0.009747 | 4.333615e-04 | 6 | 1 | 0.1 | rbf | 0.744792 | 0.677083 | ... | 0.743455 | 0.715812 | 0.029258 | 10 | 0.772174 | 0.735652 | 0.763478 | 0.755208 | 0.756628 | 0.013515 |
10 rows × 21 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 27 | 2.124819 | 0.061106 | 0.004998 | 4.256623e-07 | 4 | 0.1 | 0.001 | sigmoid | 0.338542 | 0.338542 | ... | 0.335079 | 0.337676 | 0.001500 | 167 | 0.337391 | 0.337391 | 0.337391 | 0.338542 | 0.337679 | 0.000498 |
| 159 | 2.140564 | 0.064930 | 0.006498 | 5.003216e-04 | 10 | 0.1 | 10 | sigmoid | 0.343750 | 0.338542 | ... | 0.329843 | 0.337669 | 0.004994 | 217 | 0.339130 | 0.358261 | 0.342609 | 0.342014 | 0.345503 | 0.007482 |
| 161 | 2.137314 | 0.051184 | 0.008997 | 7.068027e-04 | 10 | 0.1 | 1 | sigmoid | 0.343750 | 0.338542 | ... | 0.329843 | 0.337669 | 0.004994 | 217 | 0.339130 | 0.354783 | 0.342609 | 0.343750 | 0.345068 | 0.005861 |
| 45 | 2.121568 | 0.065900 | 0.005248 | 4.331191e-04 | 4 | 10 | auto | sigmoid | 0.343750 | 0.328125 | ... | 0.298429 | 0.337628 | 0.029485 | 219 | 0.326957 | 0.354783 | 0.422609 | 0.338542 | 0.360722 | 0.037072 |
| 119 | 2.129568 | 0.057568 | 0.006748 | 4.330502e-04 | 6 | 10 | 1 | sigmoid | 0.307292 | 0.385417 | ... | 0.314136 | 0.333742 | 0.030765 | 220 | 0.335652 | 0.403478 | 0.323478 | 0.322917 | 0.346381 | 0.033355 |
| 31 | 2.132565 | 0.057902 | 0.006498 | 4.998448e-04 | 4 | 1 | auto | sigmoid | 0.348958 | 0.317708 | ... | 0.335079 | 0.332468 | 0.011352 | 221 | 0.349565 | 0.361739 | 0.353043 | 0.312500 | 0.344212 | 0.018838 |
| 59 | 2.109323 | 0.052323 | 0.007248 | 2.277293e-03 | 4 | 100 | auto | sigmoid | 0.322917 | 0.322917 | ... | 0.303665 | 0.331125 | 0.026523 | 222 | 0.326957 | 0.354783 | 0.427826 | 0.335069 | 0.361159 | 0.039798 |
| 63 | 2.112073 | 0.054514 | 0.006248 | 1.089496e-03 | 4 | 100 | 1 | sigmoid | 0.302083 | 0.369792 | ... | 0.319372 | 0.328541 | 0.025087 | 223 | 0.321739 | 0.373913 | 0.293913 | 0.315972 | 0.326384 | 0.029340 |
| 49 | 2.113823 | 0.061155 | 0.006248 | 4.327061e-04 | 4 | 10 | 1 | sigmoid | 0.276042 | 0.375000 | ... | 0.314136 | 0.315513 | 0.036898 | 224 | 0.304348 | 0.379130 | 0.325217 | 0.312500 | 0.330299 | 0.029157 |
| 175 | 2.128068 | 0.052904 | 0.009247 | 8.287548e-04 | 10 | 1 | 1 | sigmoid | 0.375000 | 0.270833 | ... | 0.308901 | 0.303788 | 0.044901 | 225 | 0.360000 | 0.269565 | 0.335652 | 0.343750 | 0.327242 | 0.034435 |
10 rows × 21 columns
In total there are 225 different configurations tested.
The best mean test score is 0.7262
There are 1 configurations with this maximum score
There are 27 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 27.000000 | 27.000000 | 27.000000 | 2.700000e+01 | 27.0 | 27.0 | 20.00 | 27 | 27.000000 | 27.000000 | ... | 27.000000 | 27.000000 | 27.000000 | 27.000000 | 27.000000 | 27.000000 | 27.000000 | 27.000000 | 27.000000 | 27.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 5.0 | 5.00 | 3 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 6.0 | 10.0 | 0.01 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 19.0 | 8.0 | 7.00 | 14 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.138231 | 0.060986 | 0.008192 | 6.971704e-04 | NaN | NaN | NaN | NaN | 0.747492 | 0.662230 | ... | 0.723483 | 0.713375 | 0.033656 | 13.814815 | 0.761031 | 0.734364 | 0.759549 | 0.753344 | 0.752072 | 0.011619 |
| std | 0.037548 | 0.005814 | 0.003710 | 5.522133e-04 | NaN | NaN | NaN | NaN | 0.010651 | 0.011798 | ... | 0.019538 | 0.005212 | 0.005763 | 7.942278 | 0.014688 | 0.013968 | 0.013993 | 0.013464 | 0.012414 | 0.004457 |
| min | 2.111573 | 0.050226 | 0.003749 | 5.196212e-07 | NaN | NaN | NaN | NaN | 0.729167 | 0.645833 | ... | 0.685864 | 0.706622 | 0.024617 | 1.000000 | 0.735652 | 0.716522 | 0.733913 | 0.723958 | 0.730119 | 0.001243 |
| 25% | 2.120570 | 0.055907 | 0.004624 | 4.328266e-04 | NaN | NaN | NaN | NaN | 0.739583 | 0.651042 | ... | 0.706806 | 0.709281 | 0.029095 | 7.500000 | 0.753913 | 0.724348 | 0.753043 | 0.750868 | 0.747714 | 0.009465 |
| 50% | 2.128318 | 0.060513 | 0.008997 | 4.990702e-04 | NaN | NaN | NaN | NaN | 0.744792 | 0.661458 | ... | 0.727749 | 0.711865 | 0.032761 | 13.000000 | 0.760000 | 0.733913 | 0.760000 | 0.755208 | 0.754019 | 0.011538 |
| 75% | 2.142813 | 0.065022 | 0.010372 | 8.289529e-04 | NaN | NaN | NaN | NaN | 0.755208 | 0.671875 | ... | 0.740838 | 0.717107 | 0.037256 | 20.500000 | 0.770435 | 0.740870 | 0.763478 | 0.758681 | 0.757497 | 0.015070 |
| max | 2.300512 | 0.074573 | 0.014995 | 3.081520e-03 | NaN | NaN | NaN | NaN | 0.765625 | 0.677083 | ... | 0.748691 | 0.726228 | 0.045044 | 26.000000 | 0.800000 | 0.777391 | 0.808696 | 0.791667 | 0.794438 | 0.018177 |
11 rows × 21 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.7356 +- 0.0159 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 100, 'SVM__gamma': 0.01, 'SVM__kernel': 'sigmoid'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 222 | 2.182299 | 0.063694 | 0.004499 | 5.003215e-04 | 10 | 1 | NaN | linear | 0.717277 | 0.722513 | ... | 0.748691 | 0.735602 | 0.015923 | 1 | 0.790576 | 0.766143 | 0.752182 | 0.776614 | 0.771379 | 0.014070 |
| 207 | 2.176302 | 0.061567 | 0.005748 | 4.329472e-04 | 10 | 100 | 0.01 | sigmoid | 0.717277 | 0.722513 | ... | 0.748691 | 0.735602 | 0.015923 | 1 | 0.788831 | 0.766143 | 0.752182 | 0.776614 | 0.770942 | 0.013483 |
| 206 | 2.173304 | 0.073917 | 0.009496 | 4.990103e-04 | 10 | 100 | 0.01 | rbf | 0.717277 | 0.722513 | ... | 0.738220 | 0.734293 | 0.016296 | 3 | 0.795812 | 0.809773 | 0.764398 | 0.802792 | 0.793194 | 0.017343 |
| 208 | 2.188048 | 0.082224 | 0.010247 | 4.326374e-04 | 10 | 100 | 0.001 | rbf | 0.712042 | 0.722513 | ... | 0.743455 | 0.732984 | 0.016556 | 4 | 0.787086 | 0.767888 | 0.752182 | 0.785340 | 0.773124 | 0.014232 |
| 184 | 2.167556 | 0.059497 | 0.010747 | 1.298565e-03 | 10 | 10 | auto | rbf | 0.712042 | 0.743455 | ... | 0.748691 | 0.731675 | 0.014981 | 5 | 0.832461 | 0.837696 | 0.821990 | 0.839442 | 0.832897 | 0.006801 |
| 190 | 2.162807 | 0.068671 | 0.009496 | 5.002022e-04 | 10 | 10 | 0.1 | rbf | 0.712042 | 0.743455 | ... | 0.748691 | 0.731675 | 0.014981 | 5 | 0.832461 | 0.837696 | 0.821990 | 0.839442 | 0.832897 | 0.006801 |
| 223 | 2.179551 | 0.071981 | 0.005498 | 2.061071e-03 | 10 | 10 | NaN | linear | 0.717277 | 0.717277 | ... | 0.748691 | 0.729058 | 0.013023 | 7 | 0.787086 | 0.773124 | 0.764398 | 0.780105 | 0.776178 | 0.008404 |
| 192 | 2.186050 | 0.067975 | 0.010247 | 4.331198e-04 | 10 | 10 | 0.01 | rbf | 0.706806 | 0.712042 | ... | 0.748691 | 0.729058 | 0.019721 | 7 | 0.788831 | 0.778360 | 0.752182 | 0.787086 | 0.776614 | 0.014653 |
| 176 | 2.160808 | 0.064378 | 0.010996 | 4.578320e-07 | 10 | 1 | 0.1 | rbf | 0.712042 | 0.712042 | ... | 0.732984 | 0.726440 | 0.015432 | 9 | 0.792321 | 0.801047 | 0.769634 | 0.804538 | 0.791885 | 0.013595 |
| 170 | 2.183050 | 0.071905 | 0.010246 | 4.320182e-04 | 10 | 1 | auto | rbf | 0.712042 | 0.712042 | ... | 0.732984 | 0.726440 | 0.015432 | 9 | 0.792321 | 0.801047 | 0.769634 | 0.804538 | 0.791885 | 0.013595 |
10 rows × 21 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 166 | 2.200044 | 0.073691 | 0.015495 | 1.658282e-03 | 10 | 0.1 | 0.001 | rbf | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 27 | 2.186049 | 0.064834 | 0.004749 | 4.334633e-04 | 4 | 0.1 | 0.001 | sigmoid | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 26 | 2.158558 | 0.056883 | 0.011997 | 8.678567e-07 | 4 | 0.1 | 0.001 | rbf | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 25 | 2.175301 | 0.066266 | 0.004749 | 4.322932e-04 | 4 | 0.1 | 0.01 | sigmoid | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 181 | 2.170304 | 0.075860 | 0.007247 | 4.333602e-04 | 10 | 1 | 0.001 | sigmoid | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 88 | 2.168554 | 0.078395 | 0.013996 | 7.070560e-04 | 6 | 0.1 | 10 | rbf | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 83 | 2.175051 | 0.071675 | 0.005748 | 4.330159e-04 | 6 | 0.01 | 0.001 | sigmoid | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 82 | 2.178302 | 0.086497 | 0.013496 | 1.499573e-03 | 6 | 0.01 | 0.001 | rbf | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 80 | 2.176553 | 0.064804 | 0.013745 | 1.918184e-03 | 6 | 0.01 | 0.01 | rbf | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
| 167 | 2.194796 | 0.063077 | 0.006998 | 7.074771e-04 | 10 | 0.1 | 0.001 | sigmoid | 0.335079 | 0.335079 | ... | 0.340314 | 0.337696 | 0.002618 | 186 | 0.338569 | 0.338569 | 0.336824 | 0.336824 | 0.337696 | 0.000873 |
10 rows × 21 columns
In total there are 225 different configurations tested.
The best mean test score is 0.7356
There are 2 configurations with this maximum score
There are 17 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 17.000000 | 17.000000 | 17.000000 | 1.700000e+01 | 17.0 | 17.0 | 13.00 | 17 | 17.000000 | 17.000000 | ... | 17.000000 | 17.000000 | 17.000000 | 17.000000 | 17.000000 | 17.000000 | 17.000000 | 17.000000 | 17.000000 | 17.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 4.0 | 5.00 | 3 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 1.0 | 0.01 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 17.0 | 6.0 | 4.00 | 8 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.190754 | 0.069968 | 0.007718 | 5.423305e-04 | NaN | NaN | NaN | NaN | 0.708038 | 0.716046 | ... | 0.743148 | 0.727826 | 0.018031 | 8.470588 | 0.792218 | 0.783903 | 0.767580 | 0.796017 | 0.784930 | 0.012464 |
| std | 0.054922 | 0.007008 | 0.002579 | 4.957370e-04 | NaN | NaN | NaN | NaN | 0.009353 | 0.014637 | ... | 0.006537 | 0.005770 | 0.004073 | 4.862068 | 0.018402 | 0.030630 | 0.024821 | 0.024158 | 0.023920 | 0.002978 |
| min | 2.160808 | 0.059497 | 0.004499 | 3.908538e-07 | NaN | NaN | NaN | NaN | 0.680628 | 0.696335 | ... | 0.732984 | 0.715969 | 0.013023 | 1.000000 | 0.773124 | 0.748691 | 0.750436 | 0.774869 | 0.761780 | 0.006801 |
| 25% | 2.167556 | 0.064378 | 0.005748 | 4.321227e-04 | NaN | NaN | NaN | NaN | 0.701571 | 0.701571 | ... | 0.738220 | 0.725131 | 0.015432 | 5.000000 | 0.776614 | 0.760908 | 0.752182 | 0.780105 | 0.769197 | 0.012248 |
| 50% | 2.178052 | 0.068671 | 0.005998 | 4.329472e-04 | NaN | NaN | NaN | NaN | 0.712042 | 0.712042 | ... | 0.743455 | 0.726440 | 0.016296 | 9.000000 | 0.788831 | 0.771379 | 0.752182 | 0.785340 | 0.775305 | 0.013483 |
| 75% | 2.186050 | 0.073917 | 0.010247 | 5.002022e-04 | NaN | NaN | NaN | NaN | 0.712042 | 0.722513 | ... | 0.748691 | 0.731675 | 0.019721 | 13.000000 | 0.792321 | 0.801047 | 0.769634 | 0.804538 | 0.791885 | 0.014070 |
| max | 2.398481 | 0.082224 | 0.010996 | 2.061071e-03 | NaN | NaN | NaN | NaN | 0.717277 | 0.743455 | ... | 0.748691 | 0.735602 | 0.027047 | 16.000000 | 0.832461 | 0.837696 | 0.821990 | 0.848168 | 0.832897 | 0.017343 |
11 rows × 21 columns
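The two counts printed for subject E come from threshold comparisons on `mean_test_score`. The following is a minimal sketch of that logic in plain Python. It uses only the ten mean test scores visible in the top 10 table for subject E, so the second count differs from the count over all 225 configurations. The helper name `count_within` is hypothetical and not part of the notebook.

```python
# Mean test scores of the top 10 configurations for subject E,
# copied from the "Top 10 grid search results" table
top10_scores = [0.735602, 0.735602, 0.734293, 0.732984, 0.731675,
                0.731675, 0.729058, 0.729058, 0.726440, 0.726440]


def count_within(scores, margin):
    """Count the scores that lie in [max - margin, max] (hypothetical helper)."""
    best = max(scores)
    return sum(1 for score in scores if best - margin <= score <= best)


shared_first = count_within(top10_scores, 0.0)   # exact ties with the best score
close_first = count_within(top10_scores, 0.02)   # within 0.02 of the best score
print(shared_first, close_first)
```

On these ten values the sketch finds two configurations sharing the best score, which agrees with the printed output, and all ten fall within 0.02 of it.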
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# NOTE: the grid search output for subject E lists 10 CSP components as best, whereas 6 is used here
best_found_csp_components = [4, 6, 6]
best_found_svm_kernel = ["rbf", "rbf", "sigmoid"]
best_found_svm_c = [0.1, 100, 100]
best_found_svm_gamma = ["auto", 0.001, 0.01]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    # Open train and test data from file
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/testdata-x_cspsvm_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        X_test = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/testdata-y_cspsvm_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        y_test = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/traindata-x_cspsvm_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        X_train = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/traindata-y_cspsvm_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        y_train = pickle.load(f)
    # Make the classifier
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    svm = SVC(kernel= best_found_svm_kernel[i],
              C= best_found_svm_c[i],
              gamma= best_found_svm_gamma[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('SVM', svm)])
    # Fit the pipeline while suppressing the fitting output
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and confusion matrix (CM)
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_svm_kernel
del best_found_svm_c
del best_found_svm_gamma
del i
del f
del X_test
del y_test
del X_train
del y_train
del csp
del svm
del pipeline
del y_pred
del accuracy
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.6146
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.7448
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.7016
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
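This experiment works as follows: for each subject (B, C and E), the latest recording is loaded, epochs around the visual cue are extracted for the neutral, left and right tasks, a fixed 2 to 32 Hz filter is applied and the 0.1 to 0.6 second window is kept. The data is then split 80/20 into a stratified train and test set, and a grid search with 4-fold stratified cross validation is run on the training set for a CSP + random forest (RF) pipeline. The grid search object and the train and test data are stored so the best model can be retrained later.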
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        # Get MNE raw object for latest recording of that subject
        mne_raw = CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_id)
        # Get epochs for that MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Create a test and train split
        X_train, X_test, y_train, y_test = train_test_split(mne_epochs_data,
                                                            labels,
                                                            test_size = 0.2,
                                                            shuffle= True,
                                                            stratify= labels,
                                                            random_state= 1998)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        rf = RandomForestClassifier(bootstrap= True,
                                    criterion= "gini")
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('RF', rf)])
        # Configure cross validation to use
        cv = StratifiedKFold(n_splits=4,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        # NOTE: max_features uses the Python object None here; the string "None" is not a
        # valid value and makes those fits fail, which matches the NaN rows in the stored results
        param_grid = [{"CSP__n_components": [4, 6, 10],
                       "RF__n_estimators": [10, 50, 100, 250, 500],
                       "RF__max_depth": [None, 3, 10],
                       "RF__min_samples_split": [2, 5, 10],
                       "RF__max_features": ["sqrt", "log2", None, 0.2, 0.4, 0.6]}]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= X_train, y= y_train)
        # Store the results of the grid search
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/gridsearch_csprf_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Store the train and test data so the best model can be retrained later
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/testdata-x_csprf_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(X_test, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/testdata-y_csprf_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(y_test, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/traindata-x_csprf_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(X_train, file)
        with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/traindata-y_csprf_subject{subject_id}.pickle", 'wb') as file:
            pickle.dump(y_train, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del rf
        del pipeline
        del labels
        del cv
        del file
        del X_train
        del X_test
        del y_train
        del y_test
        del grid_search
        del param_grid
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
| Subject | CSP + RF: cross validation accuracy | CSP + RF: test split accuracy | Config |
|---|---|---|---|
| B | 0.6589 +- 0.03 | 0.6042 | 4 CSP components, RF max depth 10, max features 0.4, min samples split 10, 50 estimators |
| C | 0.7119 +- 0.0316 | 0.7031 | 6 CSP components, RF max depth 3, max features 0.4, min samples split 5, 250 estimators |
| E | 0.7251 +- 0.0176 | 0.7539 | 10 CSP components, RF max depth None, max features 0.2, min samples split 2, 250 estimators |
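To make the table easier to read, the difference between the test split accuracy and the cross validation accuracy can be computed per subject. The following is a small sketch in plain Python using the reported numbers. The cross validation value for subject B is taken from the printed grid search result for that subject (0.6589). The variable names are chosen for illustration only.

```python
# (cross validation accuracy, test split accuracy) per subject for CSP + RF
results = {
    "B": (0.6589, 0.6042),
    "C": (0.7119, 0.7031),
    "E": (0.7251, 0.7539),
}

# Positive gap: the test split scored higher than the cross validation mean
gaps = {subject: round(test - cv, 4) for subject, (cv, test) in results.items()}
for subject, gap in gaps.items():
    print(f"Subject {subject}: test - CV = {gap:+.4f}")

# Unweighted mean of the three test split accuracies
mean_test = round(sum(test for _, test in results.values()) / len(results), 4)
print(f"Mean test split accuracy: {mean_test}")
```

Subjects B and C score lower on the test split than in cross validation, while subject E scores higher. With a single test split per subject, differences of this size should be read with care.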
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_id}/gridsearch_csprf_subject{subject_id}.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.6589 +- 0.03 with the following parameters
{'CSP__n_components': 4, 'RF__max_depth': 10, 'RF__max_features': 0.4, 'RF__min_samples_split': 10, 'RF__n_estimators': 50}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 254 | 2.654652 | 0.033919 | 0.053983 | 0.000707 | 4 | 10 | 0.4 | 10 | 500 | 0.604167 | ... | 0.677083 | 0.658854 | 0.031574 | 1 | 0.862847 | 0.880208 | 0.838542 | 0.848958 | 0.857639 | 0.015625 |
| 251 | 2.073587 | 0.018533 | 0.007248 | 0.000433 | 4 | 10 | 0.4 | 10 | 50 | 0.609375 | ... | 0.661458 | 0.658854 | 0.030033 | 1 | 0.873264 | 0.869792 | 0.840278 | 0.848958 | 0.858073 | 0.013855 |
| 164 | 2.576927 | 0.033308 | 0.047735 | 0.000433 | 4 | 3 | 0.4 | 10 | 500 | 0.609375 | ... | 0.677083 | 0.657552 | 0.028852 | 3 | 0.718750 | 0.703125 | 0.701389 | 0.699653 | 0.705729 | 0.007617 |
| 371 | 2.060092 | 0.028208 | 0.007248 | 0.000433 | 6 | 3 | sqrt | 10 | 50 | 0.614583 | ... | 0.661458 | 0.656250 | 0.025248 | 4 | 0.720486 | 0.710069 | 0.723958 | 0.706597 | 0.715278 | 0.007158 |
| 237 | 2.134069 | 0.034848 | 0.012996 | 0.000707 | 4 | 10 | 0.2 | 10 | 100 | 0.609375 | ... | 0.677083 | 0.654948 | 0.026906 | 5 | 0.868056 | 0.868056 | 0.847222 | 0.835069 | 0.854601 | 0.014124 |
| 74 | 2.688141 | 0.042636 | 0.057982 | 0.004182 | 4 | None | 0.4 | 10 | 500 | 0.609375 | ... | 0.677083 | 0.653646 | 0.026685 | 6 | 0.892361 | 0.902778 | 0.890625 | 0.883681 | 0.892361 | 0.006835 |
| 235 | 2.036850 | 0.035629 | 0.003249 | 0.000433 | 4 | 10 | 0.2 | 10 | 10 | 0.645833 | ... | 0.651042 | 0.653646 | 0.010737 | 6 | 0.833333 | 0.857639 | 0.833333 | 0.807292 | 0.832899 | 0.017806 |
| 246 | 2.056843 | 0.037287 | 0.007748 | 0.000433 | 4 | 10 | 0.4 | 5 | 50 | 0.614583 | ... | 0.661458 | 0.653646 | 0.023438 | 6 | 0.921875 | 0.918403 | 0.906250 | 0.911458 | 0.914497 | 0.006061 |
| 238 | 2.334503 | 0.025216 | 0.027991 | 0.000707 | 4 | 10 | 0.2 | 10 | 250 | 0.609375 | ... | 0.671875 | 0.653646 | 0.025911 | 6 | 0.857639 | 0.875000 | 0.847222 | 0.843750 | 0.855903 | 0.012153 |
| 58 | 2.335754 | 0.022592 | 0.028241 | 0.000433 | 4 | None | 0.2 | 10 | 250 | 0.604167 | ... | 0.671875 | 0.652344 | 0.028852 | 10 | 0.888889 | 0.904514 | 0.881944 | 0.881944 | 0.889323 | 0.009217 |
10 rows × 22 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 305 | 1.995363 | 0.034867 | 0.0 | 0.0 | 6 | None | None | 5 | 10 | NaN | ... | NaN | NaN | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN |
| 482 | 2.036600 | 0.026072 | 0.0 | 0.0 | 6 | 10 | None | 2 | 100 | NaN | ... | NaN | NaN | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 2.010108 | 0.026329 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | NaN | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 2.078086 | 0.025035 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | NaN | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 2.044847 | 0.025236 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | NaN | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 2.024104 | 0.045518 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | NaN | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 1.992863 | 0.024867 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | NaN | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 2.153062 | 0.035509 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | NaN | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN |
| 481 | 2.021355 | 0.038204 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | NaN | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 2.149314 | 0.023689 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | NaN | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 22 columns
In total there are 810 different configurations tested.
The best mean test score is 0.6589
There are 2 configurations with this maximum score
There are 199 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.0 | 167.0 | 199.0 | 199.0 | 199.0 | 199.000000 | ... | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.000000 | 199.000000 |
| unique | NaN | NaN | NaN | NaN | 3.0 | 2.0 | 5.0 | 3.0 | 5.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 4.0 | 3.0 | 0.2 | 10.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 93.0 | 110.0 | 55.0 | 82.0 | 55.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.333009 | 0.032239 | 0.026223 | 0.001333 | NaN | NaN | NaN | NaN | NaN | 0.596289 | ... | 0.667530 | 0.644047 | 0.031271 | 93.678392 | 0.812177 | 0.812168 | 0.801656 | 0.803575 | 0.807394 | 0.007660 |
| std | 0.255482 | 0.014742 | 0.018601 | 0.001820 | NaN | NaN | NaN | NaN | NaN | 0.017572 | ... | 0.012116 | 0.004217 | 0.008114 | 55.111064 | 0.115280 | 0.112937 | 0.113609 | 0.115944 | 0.114227 | 0.003901 |
| min | 1.997112 | 0.011019 | 0.002749 | 0.000000 | NaN | NaN | NaN | NaN | NaN | 0.557292 | ... | 0.640625 | 0.639323 | 0.010737 | 1.000000 | 0.685764 | 0.682292 | 0.678819 | 0.682292 | 0.685764 | 0.000000 |
| 25% | 2.120322 | 0.025338 | 0.011247 | 0.000433 | NaN | NaN | NaN | NaN | NaN | 0.583333 | ... | 0.661458 | 0.640625 | 0.026236 | 48.000000 | 0.710069 | 0.709201 | 0.701389 | 0.701389 | 0.705295 | 0.005015 |
| 50% | 2.304763 | 0.031033 | 0.025242 | 0.000707 | NaN | NaN | NaN | NaN | NaN | 0.598958 | ... | 0.666667 | 0.643229 | 0.030230 | 95.000000 | 0.730903 | 0.743056 | 0.729167 | 0.718750 | 0.730903 | 0.006780 |
| 75% | 2.595921 | 0.036361 | 0.048860 | 0.001446 | NaN | NaN | NaN | NaN | NaN | 0.609375 | ... | 0.677083 | 0.645833 | 0.036435 | 143.000000 | 0.923611 | 0.917535 | 0.900174 | 0.912326 | 0.914062 | 0.010042 |
| max | 2.885578 | 0.202313 | 0.061980 | 0.012632 | NaN | NaN | NaN | NaN | NaN | 0.645833 | ... | 0.697917 | 0.658854 | 0.050898 | 176.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.021351 |
11 rows × 22 columns
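The 810 configurations follow from the Cartesian product of the five option lists in the parameter grid. The sketch below recomputes that count with the standard library. It also counts the configurations whose `max_features` entry is the one displayed as `None` in the tables. Assuming all of those fits failed, as the worst 10 rows suggest, they would account for the NaN scores. `NONE_OPTION` is a placeholder name introduced here, not a name from the notebook.

```python
from itertools import product

# Placeholder for the max_features entry that is displayed as None in the tables
NONE_OPTION = "None"

# Option lists of the CSP + RF grid search
param_grid = {
    "CSP__n_components": [4, 6, 10],
    "RF__n_estimators": [10, 50, 100, 250, 500],
    "RF__max_depth": [None, 3, 10],
    "RF__min_samples_split": [2, 5, 10],
    "RF__max_features": ["sqrt", "log2", NONE_OPTION, 0.2, 0.4, 0.6],
}

# Every configuration is one element of the Cartesian product of the option lists
configurations = list(product(*param_grid.values()))
total = len(configurations)

# Position of max_features inside each configuration tuple
idx = list(param_grid).index("RF__max_features")
none_count = sum(1 for config in configurations if config[idx] == NONE_OPTION)

print(total, none_count, total - none_count)
```

This gives 810 configurations in total, 135 of which use that `max_features` option, leaving 675 configurations with the other five options.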
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.7119 +- 0.0316 with the following parameters
{'CSP__n_components': 6, 'RF__max_depth': 3, 'RF__max_features': 0.4, 'RF__min_samples_split': 5, 'RF__n_estimators': 250}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 428 | 2.420227 | 0.044869 | 0.026242 | 0.001089 | 6 | 3 | 0.4 | 5 | 250 | 0.739583 | ... | 0.743455 | 0.711906 | 0.031637 | 1 | 0.766957 | 0.746087 | 0.782609 | 0.760417 | 0.764017 | 0.013122 |
| 705 | 2.091582 | 0.046924 | 0.003999 | 0.000707 | 10 | 3 | 0.6 | 2 | 10 | 0.760417 | ... | 0.732984 | 0.711892 | 0.037786 | 2 | 0.780870 | 0.730435 | 0.784348 | 0.751736 | 0.761847 | 0.022120 |
| 324 | 2.754620 | 0.045933 | 0.055482 | 0.000500 | 6 | None | 0.2 | 5 | 500 | 0.750000 | ... | 0.732984 | 0.711892 | 0.035376 | 2 | 0.994783 | 0.986087 | 0.987826 | 0.984375 | 0.988268 | 0.003954 |
| 598 | 2.442470 | 0.063389 | 0.029990 | 0.000707 | 10 | None | 0.2 | 10 | 250 | 0.739583 | ... | 0.717277 | 0.711871 | 0.036281 | 4 | 0.958261 | 0.965217 | 0.944348 | 0.949653 | 0.954370 | 0.007992 |
| 462 | 2.228788 | 0.056817 | 0.014245 | 0.001298 | 6 | 10 | sqrt | 10 | 100 | 0.760417 | ... | 0.717277 | 0.710569 | 0.032339 | 5 | 0.904348 | 0.909565 | 0.886957 | 0.894097 | 0.898742 | 0.008790 |
| 723 | 2.491953 | 0.055635 | 0.030490 | 0.001118 | 10 | 10 | sqrt | 2 | 250 | 0.744792 | ... | 0.712042 | 0.710563 | 0.033428 | 6 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 479 | 2.836594 | 0.049972 | 0.054733 | 0.001299 | 6 | 10 | log2 | 10 | 500 | 0.739583 | ... | 0.738220 | 0.709295 | 0.032295 | 7 | 0.914783 | 0.913043 | 0.886957 | 0.887153 | 0.900484 | 0.013443 |
| 298 | 2.451467 | 0.054164 | 0.030240 | 0.002585 | 6 | None | log2 | 10 | 250 | 0.744792 | ... | 0.722513 | 0.709274 | 0.028679 | 8 | 0.918261 | 0.926957 | 0.897391 | 0.899306 | 0.910479 | 0.012532 |
| 733 | 2.511947 | 0.051569 | 0.029991 | 0.000707 | 10 | 10 | sqrt | 10 | 250 | 0.739583 | ... | 0.712042 | 0.709260 | 0.035169 | 9 | 0.953043 | 0.951304 | 0.930435 | 0.949653 | 0.946109 | 0.009129 |
| 432 | 2.186051 | 0.051005 | 0.011747 | 0.000829 | 6 | 3 | 0.4 | 10 | 100 | 0.744792 | ... | 0.743455 | 0.707999 | 0.036315 | 10 | 0.779130 | 0.740870 | 0.768696 | 0.765625 | 0.763580 | 0.014035 |
10 rows × 22 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 673 | 2.136317 | 0.049244 | 0.0 | 0.0 | 10 | 3 | None | 10 | 250 | NaN | ... | NaN | NaN | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN |
| 482 | 2.109576 | 0.068961 | 0.0 | 0.0 | 6 | 10 | None | 2 | 100 | NaN | ... | NaN | NaN | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 2.072587 | 0.051152 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | NaN | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 2.174305 | 0.048053 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | NaN | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 2.094081 | 0.053492 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | NaN | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 2.092332 | 0.057013 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | NaN | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 2.073837 | 0.052130 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | NaN | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 2.217541 | 0.066166 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | NaN | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN |
| 481 | 2.069839 | 0.068613 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | NaN | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 2.211793 | 0.043371 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | NaN | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 22 columns
In total there are 810 different configurations tested.
The best mean test score is 0.7119
There are 1 configurations with this maximum score
There are 288 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 288.000000 | 288.000000 | 288.000000 | 2.880000e+02 | 288.0 | 200.0 | 288 | 288.0 | 288.0 | 288.000000 | ... | 288.000000 | 288.000000 | 288.000000 | 288.000000 | 288.000000 | 288.000000 | 288.000000 | 288.000000 | 288.000000 | 288.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 2.0 | 5 | 3.0 | 5.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 3.0 | log2 | 10.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 153.0 | 106.0 | 66 | 115.0 | 77.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.427031 | 0.053353 | 0.026555 | 1.405714e-03 | NaN | NaN | NaN | NaN | NaN | 0.731210 | ... | 0.716514 | 0.698868 | 0.034036 | 143.774306 | 0.896455 | 0.887530 | 0.893430 | 0.890999 | 0.892104 | 0.008254 |
| std | 0.285123 | 0.007998 | 0.018788 | 1.499261e-03 | NaN | NaN | NaN | NaN | NaN | 0.014024 | ... | 0.016005 | 0.004607 | 0.006704 | 83.130650 | 0.096171 | 0.107630 | 0.090088 | 0.098688 | 0.097913 | 0.005241 |
| min | 2.073588 | 0.024745 | 0.003249 | 1.032383e-07 | NaN | NaN | NaN | NaN | NaN | 0.677083 | ... | 0.670157 | 0.692299 | 0.011070 | 1.000000 | 0.739130 | 0.720000 | 0.761739 | 0.744792 | 0.747499 | 0.000000 |
| 25% | 2.182178 | 0.048045 | 0.009247 | 5.000681e-04 | NaN | NaN | NaN | NaN | NaN | 0.723958 | ... | 0.701571 | 0.694949 | 0.030145 | 72.250000 | 0.780870 | 0.754783 | 0.786087 | 0.769097 | 0.772272 | 0.003136 |
| 50% | 2.373992 | 0.053122 | 0.025367 | 8.473666e-04 | NaN | NaN | NaN | NaN | NaN | 0.734375 | ... | 0.717277 | 0.698834 | 0.033805 | 144.000000 | 0.933913 | 0.931304 | 0.914783 | 0.929688 | 0.932420 | 0.008717 |
| 75% | 2.699762 | 0.058801 | 0.048984 | 1.643769e-03 | NaN | NaN | NaN | NaN | NaN | 0.739583 | ... | 0.729058 | 0.701477 | 0.037834 | 212.500000 | 0.990000 | 0.987826 | 0.987826 | 0.987847 | 0.988266 | 0.012163 |
| max | 3.210224 | 0.074814 | 0.065480 | 1.386146e-02 | NaN | NaN | NaN | NaN | NaN | 0.765625 | ... | 0.759162 | 0.711906 | 0.054123 | 288.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.025550 |
11 rows × 22 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.7251 +- 0.0176 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': None, 'RF__max_features': 0.2, 'RF__min_samples_split': 2, 'RF__n_estimators': 250}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 588 | 2.500701 | 0.053585 | 0.029991 | 0.000707 | 10 | None | 0.2 | 2 | 250 | 0.696335 | ... | 0.727749 | 0.725131 | 0.017561 | 1 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 594 | 2.907321 | 0.054588 | 0.056482 | 0.001118 | 10 | None | 0.2 | 5 | 500 | 0.706806 | ... | 0.738220 | 0.723822 | 0.012486 | 2 | 0.998255 | 1.000000 | 1.000000 | 1.000000 | 0.999564 | 0.000756 |
| 557 | 2.284521 | 0.057858 | 0.014495 | 0.000500 | 10 | None | log2 | 2 | 100 | 0.685864 | ... | 0.748691 | 0.718586 | 0.022328 | 3 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 772 | 2.275023 | 0.057715 | 0.014495 | 0.000866 | 10 | 10 | 0.2 | 5 | 100 | 0.696335 | ... | 0.717277 | 0.717277 | 0.019590 | 4 | 0.991274 | 0.998255 | 0.989529 | 0.993019 | 0.993019 | 0.003265 |
| 742 | 2.300265 | 0.060919 | 0.013496 | 0.000500 | 10 | 10 | log2 | 5 | 100 | 0.691099 | ... | 0.712042 | 0.715969 | 0.016296 | 5 | 0.994764 | 0.993019 | 0.993019 | 0.993019 | 0.993455 | 0.000756 |
| 549 | 2.990045 | 0.063808 | 0.060981 | 0.007033 | 10 | None | sqrt | 5 | 500 | 0.691099 | ... | 0.712042 | 0.715969 | 0.023231 | 5 | 0.998255 | 0.998255 | 1.000000 | 1.000000 | 0.999127 | 0.000873 |
| 777 | 2.315011 | 0.051216 | 0.014495 | 0.000866 | 10 | 10 | 0.2 | 10 | 100 | 0.696335 | ... | 0.727749 | 0.715969 | 0.011925 | 5 | 0.959860 | 0.956370 | 0.949389 | 0.958115 | 0.955934 | 0.003975 |
| 587 | 2.265276 | 0.052534 | 0.015495 | 0.001118 | 10 | None | 0.2 | 2 | 100 | 0.675393 | ... | 0.738220 | 0.715969 | 0.027299 | 5 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 776 | 2.204795 | 0.062592 | 0.008497 | 0.000500 | 10 | 10 | 0.2 | 10 | 50 | 0.685864 | ... | 0.743455 | 0.715969 | 0.021706 | 5 | 0.947644 | 0.945899 | 0.951134 | 0.959860 | 0.951134 | 0.005379 |
| 743 | 2.543687 | 0.036337 | 0.029741 | 0.000433 | 10 | 10 | log2 | 5 | 250 | 0.701571 | ... | 0.743455 | 0.714660 | 0.024972 | 10 | 0.991274 | 0.994764 | 0.993019 | 0.998255 | 0.994328 | 0.002581 |
10 rows × 22 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 489 | 2.259528 | 0.060900 | 0.0 | 0.0 | 6 | 10 | None | 5 | 500 | NaN | ... | NaN | NaN | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN |
| 481 | 2.116824 | 0.042179 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | NaN | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 2.197548 | 0.072033 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | NaN | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 2.147314 | 0.059015 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | NaN | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 2.122823 | 0.059150 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | NaN | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 2.103827 | 0.058639 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | NaN | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 2.253530 | 0.046500 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | NaN | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN |
| 398 | 2.184302 | 0.059028 | 0.0 | 0.0 | 6 | 3 | None | 5 | 250 | NaN | ... | NaN | NaN | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 2.097080 | 0.051575 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | NaN | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 2.272524 | 0.062099 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | NaN | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 22 columns
In total there are 810 different configurations tested.
The best mean test score is 0.7251
There are 1 configurations with this maximum score
There are 61 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | split3_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 61.000000 | 61.000000 | 61.000000 | 6.100000e+01 | 61.0 | 33.0 | 61.0 | 61.0 | 61.0 | 61.000000 | ... | 61.000000 | 61.000000 | 61.000000 | 61.000000 | 61.000000 | 61.000000 | 61.000000 | 61.000000 | 61.000000 | 61.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 2.0 | 5.0 | 3.0 | 4.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | 0.2 | 5.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 61.0 | 32.0 | 23.0 | 21.0 | 21.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 2.599846 | 0.058300 | 0.032809 | 1.156652e-03 | NaN | NaN | NaN | NaN | NaN | 0.688868 | ... | 0.721483 | 0.710819 | 0.018027 | 29.245902 | 0.984064 | 0.984007 | 0.982147 | 0.985180 | 0.983850 | 0.002078 |
| std | 0.337803 | 0.012892 | 0.019007 | 1.223278e-03 | NaN | NaN | NaN | NaN | NaN | 0.010227 | ... | 0.012675 | 0.004294 | 0.005382 | 17.376858 | 0.027420 | 0.028592 | 0.030035 | 0.026266 | 0.027985 | 0.002000 |
| min | 2.178803 | 0.032044 | 0.008497 | 1.032383e-07 | NaN | NaN | NaN | NaN | NaN | 0.664921 | ... | 0.691099 | 0.705497 | 0.007744 | 1.000000 | 0.806283 | 0.804538 | 0.795812 | 0.821990 | 0.807155 | 0.000000 |
| 25% | 2.286519 | 0.051035 | 0.014495 | 4.996658e-04 | NaN | NaN | NaN | NaN | NaN | 0.680628 | ... | 0.712042 | 0.708115 | 0.013539 | 15.000000 | 0.970332 | 0.966841 | 0.963351 | 0.970332 | 0.967714 | 0.000756 |
| 50% | 2.537939 | 0.057715 | 0.029991 | 8.289706e-04 | NaN | NaN | NaN | NaN | NaN | 0.691099 | ... | 0.722513 | 0.710733 | 0.019370 | 28.000000 | 0.994764 | 0.996510 | 0.993019 | 0.996510 | 0.994764 | 0.001511 |
| 75% | 2.926564 | 0.063010 | 0.055733 | 1.298726e-03 | NaN | NaN | NaN | NaN | NaN | 0.696335 | ... | 0.727749 | 0.713351 | 0.021706 | 40.000000 | 0.998255 | 1.000000 | 1.000000 | 1.000000 | 0.999564 | 0.003265 |
| max | 3.325688 | 0.111701 | 0.060981 | 7.033335e-03 | NaN | NaN | NaN | NaN | NaN | 0.706806 | ... | 0.748691 | 0.725131 | 0.032354 | 56.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.009439 |
11 rows × 22 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
best_found_csp_components = [4, 6, 10]
best_found_rf_depth = [10, 3, None]
best_found_rf_max_featues = [0.4, 0.4, 0.2]
best_found_rf_min_sample = [10, 5, 2]
best_found_rf_n_estimators = [50, 250, 250]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    # Open train and test data from file
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/testdata-x_csprf_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        X_test = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/testdata-y_csprf_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        y_test = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/traindata-x_csprf_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        X_train = pickle.load(f)
    with open(f"saved_variables/2/samesubject_samesession/subject{subject_ids_to_test[i]}/traindata-y_csprf_subject{subject_ids_to_test[i]}.pickle", 'rb') as f:
        y_train = pickle.load(f)
    # Make the classifier with the best parameters found for this subject
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    rf = RandomForestClassifier(bootstrap= True,
                                criterion= "gini",
                                max_depth= best_found_rf_depth[i],
                                max_features= best_found_rf_max_featues[i],
                                min_samples_split= best_found_rf_min_sample[i],
                                n_estimators= best_found_rf_n_estimators[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('RF', rf)])
    # Fit the pipeline, suppressing the verbose fitting output
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_rf_depth
del best_found_rf_max_featues
del best_found_rf_min_sample
del best_found_rf_n_estimators
del i
del f
del X_test
del y_test
del X_train
del y_train
del csp
del rf
del pipeline
del y_pred
del accuracy
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.5938
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.7083
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.733
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
As discussed in the master's thesis, a classification system can be trained and tested using multiple strategies. A classifier may be trained on a single subject, using one or more sessions for training and a new, unseen session for testing. This is a harder task than the previous one, where training and testing were done within the same session. This section trains the same classifiers for the same participants as before, but uses the first two sessions as training data and the third and final session of each participant as a standalone test set that is not used in training.
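The session-based split described above can be sketched in a few lines. This is only an illustration on toy data: the function name `split_last_session` and the string placeholders standing in for epochs are assumptions made for this example, and the notebook itself builds its training data from MNE raws through the `CLA_dataset` utilities instead.

```python
def split_last_session(sessions):
    """Use all but the last session for training and the last one for testing.

    `sessions` is a chronologically ordered list of (X, y) pairs, one per session.
    This helper is hypothetical and not part of the CLA_dataset utilities.
    """
    if len(sessions) < 2:
        raise ValueError("Need at least two sessions to hold one out")
    *train_sessions, test_session = sessions
    # Concatenate the epochs and labels of all training sessions
    X_train = [epoch for X, _ in train_sessions for epoch in X]
    y_train = [label for _, y in train_sessions for label in y]
    X_test, y_test = test_session
    return X_train, y_train, list(X_test), list(y_test)


# Toy data: three sessions with placeholder "epochs" and class labels 0, 1, 2
sessions = [
    (["s1_e1", "s1_e2"], [0, 1]),
    (["s2_e1", "s2_e2"], [1, 2]),
    (["s3_e1", "s3_e2"], [2, 0]),
]
X_train, y_train, X_test, y_test = split_last_session(sessions)
print(f"{len(X_train)} training epochs, {len(X_test)} test epochs")
```

Keeping the last session entirely out of training means that any cross-validation on the training data never sees epochs from the test session.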
This experiment works as follows:
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        # Get all training data (all but last session of participant)
        mne_raws= CLA_dataset.get_all_but_last_raw_mne_data_for_subject(subject_id= subject_id)
        # Combine training data into a single MNE raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Delete all raws since concat changes them
        del mne_raws
        # Get epochs for the combined MNE raw (all training sessions)
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        lda = LinearDiscriminantAnalysis(shrinkage= None,
                                         priors=[1/3, 1/3, 1/3])
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('LDA', lda)])
        # Configure cross validation to use, more splits than before since we have more data
        cv = StratifiedKFold(n_splits= 6,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        param_grid = [{"CSP__n_components": [2, 3, 4, 6, 10],
                       "LDA__solver": ["svd"],
                       "LDA__tol": [0.0001, 0.00001, 0.001, 0.0004, 0.00007]
                       },
                      {"CSP__n_components": [2, 3, 4, 6, 10],
                       "LDA__solver": ["lsqr" , "eigen"]
                       }]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= mne_epochs_data, y= labels)
        # Store the results of the grid search
        with open(f"saved_variables/2/samesubject_differentsession/subject{subject_id}/gridsearch_csplda.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Delete vars after a single experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del lda
        del pipeline
        del labels
        del cv
        del file
        del grid_search
        del param_grid
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
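As a sanity check on the size of the search defined in the cell above, the number of candidate configurations can be counted by hand. The sketch below copies the two parameter dictionaries from that cell and expands them with the standard library; the helper name `count_configurations` is an assumption made for this example, and it is meant to mirror how a list of grids is expanded (each dictionary separately, then summed), not to replace the grid search itself.

```python
from itertools import product

# Same hyperparameter grids as in the grid search cell
param_grid = [
    {"CSP__n_components": [2, 3, 4, 6, 10],
     "LDA__solver": ["svd"],
     "LDA__tol": [0.0001, 0.00001, 0.001, 0.0004, 0.00007]},
    {"CSP__n_components": [2, 3, 4, 6, 10],
     "LDA__solver": ["lsqr", "eigen"]},
]


def count_configurations(grids):
    """Count candidates: each dict is expanded on its own and the counts are summed."""
    total = 0
    for grid in grids:
        total += len(list(product(*grid.values())))
    return total


n_configs = count_configurations(param_grid)  # 5 * 1 * 5 + 5 * 2
n_splits = 6                                  # folds of the StratifiedKFold above
n_fits = n_configs * n_splits
print(f"{n_configs} configurations, {n_fits} fits per subject")
```

The first dictionary contributes 25 candidates and the second 10, which agrees with the 35 configurations reported in the grid search output for each subject.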
The CV results are based on the training set alone and thus only look at the first two sessions. The test result is for a new, unseen session and thus scores are expected to differ.
| Subject | CSP + LDA: cross validation accuracy | CSP + LDA: test split accuracy | Config |
|---|---|---|---|
| B | 0.4500 +- 0.02576 | 0.4677 | CSP 10 components, SVD LDA with 0.0001 tol |
| C | 0.8177 +- 0.01940 | 0.3587 | CSP 10 components, SVD LDA with 0.0001 tol |
| E | 0.5525 +- 0.03678 | 0.5518 | CSP 10 components, SVD LDA with 0.0001 tol |
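To make the difference between the cross-validation and test scores easier to read, the gap per subject can be computed directly from the table. This is a small sketch using only the numbers reported above; the `summarize` helper is hypothetical, and the chance level of 1/3 is an assumption based on three equally likely classes (neutral, left, right).

```python
# (cross-validation accuracy, test accuracy) per subject, copied from the table
results = {
    "B": (0.4500, 0.4677),
    "C": (0.8177, 0.3587),
    "E": (0.5525, 0.5518),
}
chance_level = 1 / 3  # Assumption: three balanced classes


def summarize(results, chance):
    """Return the CV-to-test accuracy gap and a chance-level check per subject."""
    summary = {}
    for subject, (cv_acc, test_acc) in results.items():
        summary[subject] = {
            "gap": round(cv_acc - test_acc, 4),
            "above_chance": test_acc > chance,
        }
    return summary


summary = summarize(results, chance_level)
for subject, stats in summary.items():
    print(f"Subject {subject}: gap = {stats['gap']:+.4f}, above chance: {stats['above_chance']}")
```

A positive gap means the cross-validation score on the first two sessions was higher than the accuracy on the unseen third session.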
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/samesubject_differentsession/subject{subject_id}/gridsearch_csplda.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.45 +- 0.0258 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34 | 5.325965 | 0.107249 | 0.004832 | 0.000373 | 10 | eigen | NaN | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508761 | 0.501252 | 0.515019 | 0.484053 | 0.494059 | 0.503548 | 0.011883 |
| 24 | 5.413773 | 0.036873 | 0.004332 | 0.000471 | 10 | svd | 0.00007 | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508135 | 0.503129 | 0.515019 | 0.484053 | 0.494059 | 0.503757 | 0.011803 |
| 23 | 5.456924 | 0.048601 | 0.004832 | 0.000373 | 10 | svd | 0.0004 | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508135 | 0.503129 | 0.515019 | 0.484053 | 0.494059 | 0.503757 | 0.011803 |
| 22 | 5.481582 | 0.074412 | 0.005499 | 0.000764 | 10 | svd | 0.001 | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508135 | 0.503129 | 0.515019 | 0.484053 | 0.494059 | 0.503757 | 0.011803 |
| 21 | 5.435931 | 0.068638 | 0.004832 | 0.000687 | 10 | svd | 0.00001 | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508135 | 0.503129 | 0.515019 | 0.484053 | 0.494059 | 0.503757 | 0.011803 |
| 20 | 5.407107 | 0.060543 | 0.004499 | 0.000500 | 10 | svd | 0.0001 | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508135 | 0.503129 | 0.515019 | 0.484053 | 0.494059 | 0.503757 | 0.011803 |
| 33 | 5.501247 | 0.045590 | 0.004832 | 0.000373 | 10 | lsqr | NaN | 0.4625 | 0.459375 | 0.434375 | ... | 0.025761 | 1 | 0.518148 | 0.508761 | 0.501252 | 0.515019 | 0.484053 | 0.494059 | 0.503548 | 0.011883 |
| 32 | 5.401942 | 0.043255 | 0.003166 | 0.000373 | 6 | eigen | NaN | 0.4500 | 0.403125 | 0.406250 | ... | 0.019975 | 8 | 0.482478 | 0.487484 | 0.447434 | 0.482478 | 0.464665 | 0.454659 | 0.469867 | 0.015221 |
| 31 | 5.401275 | 0.045625 | 0.003333 | 0.000472 | 6 | lsqr | NaN | 0.4500 | 0.403125 | 0.406250 | ... | 0.019975 | 8 | 0.482478 | 0.487484 | 0.447434 | 0.482478 | 0.464665 | 0.454659 | 0.469867 | 0.015221 |
| 19 | 5.377282 | 0.043226 | 0.003499 | 0.000500 | 6 | svd | 0.00007 | 0.4500 | 0.403125 | 0.403125 | ... | 0.020556 | 10 | 0.481852 | 0.487484 | 0.447434 | 0.481852 | 0.464665 | 0.454659 | 0.469658 | 0.015050 |
10 rows × 24 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 12 | 5.369285 | 0.046170 | 0.003332 | 0.000471 | 4 | svd | 0.001 | 0.428125 | 0.44375 | 0.365625 | ... | 0.033162 | 22 | 0.426783 | 0.449937 | 0.419274 | 0.427409 | 0.434647 | 0.437148 | 0.432533 | 0.009694 |
| 29 | 5.413272 | 0.059934 | 0.003332 | 0.000471 | 4 | lsqr | NaN | 0.428125 | 0.44375 | 0.365625 | ... | 0.033162 | 22 | 0.426783 | 0.449937 | 0.418648 | 0.428035 | 0.434647 | 0.437148 | 0.432533 | 0.009787 |
| 30 | 5.397610 | 0.065022 | 0.003332 | 0.000471 | 4 | eigen | NaN | 0.428125 | 0.44375 | 0.365625 | ... | 0.033162 | 22 | 0.426783 | 0.449937 | 0.418648 | 0.428035 | 0.434647 | 0.437148 | 0.432533 | 0.009787 |
| 25 | 5.402943 | 0.036143 | 0.002499 | 0.000500 | 2 | lsqr | NaN | 0.390625 | 0.45625 | 0.396875 | ... | 0.044478 | 29 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419012 | 0.432770 | 0.413553 | 0.024874 |
| 26 | 5.390612 | 0.048714 | 0.002499 | 0.000500 | 2 | eigen | NaN | 0.390625 | 0.45625 | 0.396875 | ... | 0.044478 | 29 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419012 | 0.432770 | 0.413553 | 0.024874 |
| 1 | 5.622225 | 0.041098 | 0.002331 | 0.000470 | 2 | svd | 0.00001 | 0.390625 | 0.45625 | 0.396875 | ... | 0.044021 | 31 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419637 | 0.432770 | 0.413657 | 0.024898 |
| 3 | 5.546569 | 0.043613 | 0.002166 | 0.000373 | 2 | svd | 0.0004 | 0.390625 | 0.45625 | 0.396875 | ... | 0.044021 | 31 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419637 | 0.432770 | 0.413657 | 0.024898 |
| 4 | 5.536681 | 0.070654 | 0.002166 | 0.000372 | 2 | svd | 0.00007 | 0.390625 | 0.45625 | 0.396875 | ... | 0.044021 | 31 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419637 | 0.432770 | 0.413657 | 0.024898 |
| 2 | 5.489227 | 0.057258 | 0.002333 | 0.000471 | 2 | svd | 0.001 | 0.390625 | 0.45625 | 0.396875 | ... | 0.044021 | 31 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419637 | 0.432770 | 0.413657 | 0.024898 |
| 0 | 5.790951 | 0.083763 | 0.002333 | 0.000471 | 2 | svd | 0.0001 | 0.390625 | 0.45625 | 0.396875 | ... | 0.044021 | 31 | 0.429912 | 0.439925 | 0.386108 | 0.373592 | 0.419637 | 0.432770 | 0.413657 | 0.024898 |
10 rows × 24 columns
In total there are 35 different configurations tested.
The best mean test score is 0.45
There are 7 configurations with this maximum score
There are 7 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.0 | 7 | 5.00000 | 7.000000e+00 | 7.000000e+00 | 7.000000e+00 | ... | 7.000000e+00 | 7.0 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000e+00 | 7.000000e+00 | 7.000000 | 7.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 3 | 5.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 5 | 1.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.431790 | 0.063129 | 0.004808 | 0.000506 | NaN | NaN | NaN | 4.625000e-01 | 4.593750e-01 | 4.343750e-01 | ... | 2.576118e-02 | 1.0 | 0.518148 | 0.508314 | 0.502593 | 0.515019 | 4.840525e-01 | 4.940588e-01 | 0.503697 | 0.011826 |
| std | 0.057858 | 0.023519 | 0.000365 | 0.000160 | NaN | NaN | NaN | 5.995890e-17 | 5.995890e-17 | 5.995890e-17 | ... | 7.494862e-18 | 0.0 | 0.000000 | 0.000305 | 0.000916 | 0.000000 | 5.995890e-17 | 5.995890e-17 | 0.000102 | 0.000039 |
| min | 5.325965 | 0.036873 | 0.004332 | 0.000373 | NaN | NaN | NaN | 4.625000e-01 | 4.593750e-01 | 4.343750e-01 | ... | 2.576118e-02 | 1.0 | 0.518148 | 0.508135 | 0.501252 | 0.515019 | 4.840525e-01 | 4.940588e-01 | 0.503548 | 0.011803 |
| 25% | 5.410440 | 0.047096 | 0.004665 | 0.000373 | NaN | NaN | NaN | 4.625000e-01 | 4.593750e-01 | 4.343750e-01 | ... | 2.576118e-02 | 1.0 | 0.518148 | 0.508135 | 0.502190 | 0.515019 | 4.840525e-01 | 4.940588e-01 | 0.503653 | 0.011803 |
| 50% | 5.435931 | 0.060543 | 0.004832 | 0.000471 | NaN | NaN | NaN | 4.625000e-01 | 4.593750e-01 | 4.343750e-01 | ... | 2.576118e-02 | 1.0 | 0.518148 | 0.508135 | 0.503129 | 0.515019 | 4.840525e-01 | 4.940588e-01 | 0.503757 | 0.011803 |
| 75% | 5.469253 | 0.071525 | 0.004832 | 0.000593 | NaN | NaN | NaN | 4.625000e-01 | 4.593750e-01 | 4.343750e-01 | ... | 2.576118e-02 | 1.0 | 0.518148 | 0.508448 | 0.503129 | 0.515019 | 4.840525e-01 | 4.940588e-01 | 0.503757 | 0.011843 |
| max | 5.501247 | 0.107249 | 0.005499 | 0.000764 | NaN | NaN | NaN | 4.625000e-01 | 4.593750e-01 | 4.343750e-01 | ... | 2.576118e-02 | 1.0 | 0.518148 | 0.508761 | 0.503129 | 0.515019 | 4.840525e-01 | 4.940588e-01 | 0.503757 | 0.011883 |
11 rows × 24 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.8177 +- 0.0194 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34 | 5.372951 | 0.073962 | 0.004665 | 0.000472 | 10 | eigen | NaN | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.828125 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833854 | 0.005259 |
| 24 | 5.443928 | 0.036729 | 0.004832 | 0.000373 | 10 | svd | 0.00007 | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.827500 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833750 | 0.005376 |
| 23 | 5.451259 | 0.055908 | 0.004665 | 0.000472 | 10 | svd | 0.0004 | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.827500 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833750 | 0.005376 |
| 22 | 5.451593 | 0.051584 | 0.004665 | 0.000471 | 10 | svd | 0.001 | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.827500 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833750 | 0.005376 |
| 21 | 5.456924 | 0.071654 | 0.004832 | 0.000373 | 10 | svd | 0.00001 | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.827500 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833750 | 0.005376 |
| 20 | 5.426434 | 0.043945 | 0.004499 | 0.000500 | 10 | svd | 0.0001 | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.827500 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833750 | 0.005376 |
| 33 | 5.458590 | 0.059358 | 0.005332 | 0.000942 | 10 | lsqr | NaN | 0.815625 | 0.7875 | 0.853125 | ... | 0.019404 | 1 | 0.828125 | 0.8350 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833854 | 0.005259 |
| 32 | 5.425934 | 0.062461 | 0.003332 | 0.000471 | 6 | eigen | NaN | 0.800000 | 0.7500 | 0.840625 | ... | 0.026983 | 8 | 0.806875 | 0.8225 | 0.805000 | 0.808750 | 0.811250 | 0.815000 | 0.811562 | 0.005838 |
| 31 | 5.436431 | 0.059140 | 0.003499 | 0.000764 | 6 | lsqr | NaN | 0.800000 | 0.7500 | 0.840625 | ... | 0.026983 | 8 | 0.806875 | 0.8225 | 0.805000 | 0.808750 | 0.811250 | 0.815000 | 0.811562 | 0.005838 |
| 19 | 5.399942 | 0.064668 | 0.003333 | 0.000471 | 6 | svd | 0.00007 | 0.800000 | 0.7500 | 0.840625 | ... | 0.026983 | 8 | 0.806875 | 0.8225 | 0.805000 | 0.808750 | 0.811250 | 0.815000 | 0.811562 | 0.005838 |
10 rows × 24 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3 | 5.422935 | 0.050707 | 0.002499 | 4.997651e-04 | 2 | svd | 0.0004 | 0.793750 | 0.740625 | 0.809375 | ... | 0.022438 | 22 | 0.780 | 0.793750 | 0.783750 | 0.788750 | 0.798125 | 0.787500 | 0.788646 | 0.005999 |
| 4 | 5.401108 | 0.049301 | 0.002499 | 7.636033e-04 | 2 | svd | 0.00007 | 0.793750 | 0.740625 | 0.809375 | ... | 0.022438 | 22 | 0.780 | 0.793750 | 0.783750 | 0.788750 | 0.798125 | 0.787500 | 0.788646 | 0.005999 |
| 0 | 5.388612 | 0.044534 | 0.002666 | 7.454265e-04 | 2 | svd | 0.0001 | 0.793750 | 0.740625 | 0.809375 | ... | 0.022438 | 22 | 0.780 | 0.793750 | 0.783750 | 0.788750 | 0.798125 | 0.787500 | 0.788646 | 0.005999 |
| 5 | 5.416604 | 0.052432 | 0.002166 | 3.721891e-04 | 3 | svd | 0.0001 | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.799375 | 0.784375 | 0.795625 | 0.794375 | 0.789375 | 0.791354 | 0.005548 |
| 27 | 5.410772 | 0.058162 | 0.002332 | 4.715110e-04 | 3 | lsqr | NaN | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.800000 | 0.784375 | 0.795000 | 0.794375 | 0.789375 | 0.791354 | 0.005630 |
| 28 | 5.412604 | 0.053445 | 0.002999 | 1.123916e-07 | 3 | eigen | NaN | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.800000 | 0.784375 | 0.795000 | 0.794375 | 0.789375 | 0.791354 | 0.005630 |
| 6 | 5.415437 | 0.061769 | 0.002333 | 4.713142e-04 | 3 | svd | 0.00001 | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.799375 | 0.784375 | 0.795625 | 0.794375 | 0.789375 | 0.791354 | 0.005548 |
| 7 | 5.395943 | 0.049596 | 0.003332 | 1.246938e-03 | 3 | svd | 0.001 | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.799375 | 0.784375 | 0.795625 | 0.794375 | 0.789375 | 0.791354 | 0.005548 |
| 8 | 5.387613 | 0.055018 | 0.002499 | 5.001228e-04 | 3 | svd | 0.0004 | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.799375 | 0.784375 | 0.795625 | 0.794375 | 0.789375 | 0.791354 | 0.005548 |
| 9 | 5.403441 | 0.068141 | 0.002666 | 4.713145e-04 | 3 | svd | 0.00007 | 0.790625 | 0.737500 | 0.815625 | ... | 0.023576 | 29 | 0.785 | 0.799375 | 0.784375 | 0.795625 | 0.794375 | 0.789375 | 0.791354 | 0.005548 |
10 rows × 24 columns
In total there are 35 different configurations tested.
The best mean test score is 0.8177
There are 7 configurations with this maximum score
There are 14 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.0 | 14 | 10.00000 | 14.000000 | 14.000000 | 14.000000 | ... | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 3 | 5.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 10 | 2.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.427005 | 0.055739 | 0.004130 | 0.000520 | NaN | NaN | NaN | 0.807813 | 0.768750 | 0.846875 | ... | 0.023194 | 4.500000 | 0.817277 | 0.828750 | 0.818437 | 0.818437 | 0.827187 | 0.825937 | 0.822671 | 0.005590 |
| std | 0.024907 | 0.012719 | 0.000707 | 0.000151 | NaN | NaN | NaN | 0.008107 | 0.019458 | 0.006486 | ... | 0.003932 | 3.632122 | 0.010796 | 0.006486 | 0.013945 | 0.010053 | 0.016539 | 0.011350 | 0.011528 | 0.000260 |
| min | 5.372951 | 0.034784 | 0.003332 | 0.000373 | NaN | NaN | NaN | 0.800000 | 0.750000 | 0.840625 | ... | 0.019404 | 1.000000 | 0.806875 | 0.822500 | 0.805000 | 0.808750 | 0.811250 | 0.815000 | 0.811562 | 0.005259 |
| 25% | 5.411355 | 0.045855 | 0.003499 | 0.000471 | NaN | NaN | NaN | 0.800000 | 0.750000 | 0.840625 | ... | 0.019404 | 1.000000 | 0.806875 | 0.822500 | 0.805000 | 0.808750 | 0.811250 | 0.815000 | 0.811562 | 0.005376 |
| 50% | 5.426184 | 0.059249 | 0.004082 | 0.000472 | NaN | NaN | NaN | 0.807813 | 0.768750 | 0.846875 | ... | 0.023194 | 4.500000 | 0.817187 | 0.828750 | 0.818438 | 0.818437 | 0.827187 | 0.825937 | 0.822656 | 0.005607 |
| 75% | 5.449426 | 0.064116 | 0.004665 | 0.000500 | NaN | NaN | NaN | 0.815625 | 0.787500 | 0.853125 | ... | 0.026983 | 8.000000 | 0.827500 | 0.835000 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833750 | 0.005838 |
| max | 5.458590 | 0.073962 | 0.005332 | 0.000942 | NaN | NaN | NaN | 0.815625 | 0.787500 | 0.853125 | ... | 0.026983 | 8.000000 | 0.828125 | 0.835000 | 0.831875 | 0.828125 | 0.843125 | 0.836875 | 0.833854 | 0.005838 |
11 rows × 24 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.5525 +- 0.0368 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 24 | 5.404738 | 0.047325 | 0.004666 | 0.000471 | 10 | svd | 0.00007 | 0.487500 | 0.531250 | 0.565625 | ... | 0.036775 | 1 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 0.561952 | 0.570685 | 0.008456 |
| 23 | 5.414544 | 0.053654 | 0.004666 | 0.000471 | 10 | svd | 0.0004 | 0.487500 | 0.531250 | 0.565625 | ... | 0.036775 | 1 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 0.561952 | 0.570685 | 0.008456 |
| 22 | 5.423074 | 0.018591 | 0.005165 | 0.000373 | 10 | svd | 0.001 | 0.487500 | 0.531250 | 0.565625 | ... | 0.036775 | 1 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 0.561952 | 0.570685 | 0.008456 |
| 21 | 5.443097 | 0.061504 | 0.004832 | 0.000373 | 10 | svd | 0.00001 | 0.487500 | 0.531250 | 0.565625 | ... | 0.036775 | 1 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 0.561952 | 0.570685 | 0.008456 |
| 20 | 5.403401 | 0.065638 | 0.004833 | 0.000372 | 10 | svd | 0.0001 | 0.487500 | 0.531250 | 0.565625 | ... | 0.036775 | 1 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 0.561952 | 0.570685 | 0.008456 |
| 34 | 5.356239 | 0.100794 | 0.004661 | 0.000473 | 10 | eigen | NaN | 0.487500 | 0.531250 | 0.565625 | ... | 0.036625 | 6 | 0.569192 | 0.583594 | 0.571697 | 0.558198 | 0.576971 | 0.561952 | 0.570267 | 0.008569 |
| 33 | 5.434308 | 0.052510 | 0.004665 | 0.000471 | 10 | lsqr | NaN | 0.487500 | 0.531250 | 0.565625 | ... | 0.036625 | 6 | 0.569192 | 0.583594 | 0.571697 | 0.558198 | 0.576971 | 0.561952 | 0.570267 | 0.008569 |
| 19 | 5.386089 | 0.042491 | 0.005498 | 0.004715 | 6 | svd | 0.00007 | 0.459375 | 0.484375 | 0.496875 | ... | 0.036737 | 8 | 0.517846 | 0.538510 | 0.512837 | 0.514393 | 0.512516 | 0.530038 | 0.521023 | 0.009836 |
| 18 | 5.388858 | 0.038988 | 0.003499 | 0.000500 | 6 | svd | 0.0004 | 0.459375 | 0.484375 | 0.496875 | ... | 0.036737 | 8 | 0.517846 | 0.538510 | 0.512837 | 0.514393 | 0.512516 | 0.530038 | 0.521023 | 0.009836 |
| 16 | 5.391821 | 0.041460 | 0.003332 | 0.000471 | 6 | svd | 0.00001 | 0.459375 | 0.484375 | 0.496875 | ... | 0.036737 | 8 | 0.517846 | 0.538510 | 0.512837 | 0.514393 | 0.512516 | 0.530038 | 0.521023 | 0.009836 |
10 rows × 24 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 5.396208 | 0.064162 | 0.002499 | 4.995666e-04 | 3 | svd | 0.0001 | 0.3750 | 0.459375 | 0.478125 | ... | 0.036877 | 22 | 0.472761 | 0.479649 | 0.479649 | 0.432416 | 0.491865 | 0.436170 | 0.465418 | 0.022741 |
| 28 | 5.384392 | 0.064971 | 0.002833 | 3.724556e-04 | 3 | eigen | NaN | 0.3750 | 0.456250 | 0.478125 | ... | 0.036647 | 27 | 0.472761 | 0.479023 | 0.479649 | 0.432416 | 0.492491 | 0.436170 | 0.465418 | 0.022800 |
| 27 | 5.385055 | 0.035717 | 0.002666 | 4.714547e-04 | 3 | lsqr | NaN | 0.3750 | 0.456250 | 0.478125 | ... | 0.036647 | 27 | 0.472761 | 0.479023 | 0.479649 | 0.432416 | 0.492491 | 0.436170 | 0.465418 | 0.022800 |
| 2 | 5.382403 | 0.035026 | 0.002166 | 3.726688e-04 | 2 | svd | 0.001 | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.433312 | 0.428929 | 0.432060 | 0.436796 | 0.441176 | 0.438048 | 0.435054 | 0.004062 |
| 25 | 5.383334 | 0.041997 | 0.002000 | 1.820952e-07 | 2 | lsqr | NaN | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.432686 | 0.428929 | 0.432060 | 0.436170 | 0.441802 | 0.438673 | 0.435054 | 0.004320 |
| 1 | 5.373709 | 0.053912 | 0.003166 | 1.343046e-03 | 2 | svd | 0.00001 | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.433312 | 0.428929 | 0.432060 | 0.436796 | 0.441176 | 0.438048 | 0.435054 | 0.004062 |
| 3 | 5.392224 | 0.078592 | 0.002333 | 4.711456e-04 | 2 | svd | 0.0004 | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.433312 | 0.428929 | 0.432060 | 0.436796 | 0.441176 | 0.438048 | 0.435054 | 0.004062 |
| 4 | 5.374064 | 0.041719 | 0.002666 | 4.712300e-04 | 2 | svd | 0.00007 | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.433312 | 0.428929 | 0.432060 | 0.436796 | 0.441176 | 0.438048 | 0.435054 | 0.004062 |
| 26 | 5.370033 | 0.051684 | 0.002333 | 1.698825e-03 | 2 | eigen | NaN | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.432686 | 0.428929 | 0.432060 | 0.436170 | 0.441802 | 0.438673 | 0.435054 | 0.004320 |
| 0 | 5.394867 | 0.043142 | 0.002499 | 4.997255e-04 | 2 | svd | 0.0001 | 0.3125 | 0.406250 | 0.453125 | ... | 0.048248 | 29 | 0.433312 | 0.428929 | 0.432060 | 0.436796 | 0.441176 | 0.438048 | 0.435054 | 0.004062 |
10 rows × 24 columns
In total there are 35 different configurations tested.
The best mean test score is 0.5525
There are 5 configurations with this maximum score
There are 7 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.0 | 7 | 5.00000 | 7.000000e+00 | 7.00000 | 7.000000e+00 | ... | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000e+00 | 7.000000 | 7.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 3 | 5.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 5 | 1.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.411343 | 0.057145 | 0.004784 | 0.000429 | NaN | NaN | NaN | 4.875000e-01 | 0.53125 | 5.656250e-01 | ... | 0.036732 | 2.428571 | 0.569640 | 0.583594 | 0.572591 | 0.558645 | 0.576971 | 5.619524e-01 | 0.570566 | 0.008488 |
| std | 0.028363 | 0.024539 | 0.000186 | 0.000053 | NaN | NaN | NaN | 5.995890e-17 | 0.00000 | 1.199178e-16 | ... | 0.000073 | 2.439750 | 0.000306 | 0.000000 | 0.000611 | 0.000305 | 0.000000 | 1.199178e-16 | 0.000204 | 0.000055 |
| min | 5.356239 | 0.018591 | 0.004661 | 0.000372 | NaN | NaN | NaN | 4.875000e-01 | 0.53125 | 5.656250e-01 | ... | 0.036625 | 1.000000 | 0.569192 | 0.583594 | 0.571697 | 0.558198 | 0.576971 | 5.619524e-01 | 0.570267 | 0.008456 |
| 25% | 5.404070 | 0.049918 | 0.004665 | 0.000373 | NaN | NaN | NaN | 4.875000e-01 | 0.53125 | 5.656250e-01 | ... | 0.036700 | 1.000000 | 0.569505 | 0.583594 | 0.572323 | 0.558511 | 0.576971 | 5.619524e-01 | 0.570476 | 0.008456 |
| 50% | 5.414544 | 0.053654 | 0.004666 | 0.000471 | NaN | NaN | NaN | 4.875000e-01 | 0.53125 | 5.656250e-01 | ... | 0.036775 | 1.000000 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 5.619524e-01 | 0.570685 | 0.008456 |
| 75% | 5.428691 | 0.063571 | 0.004832 | 0.000471 | NaN | NaN | NaN | 4.875000e-01 | 0.53125 | 5.656250e-01 | ... | 0.036775 | 3.500000 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 5.619524e-01 | 0.570685 | 0.008512 |
| max | 5.443097 | 0.100794 | 0.005165 | 0.000473 | NaN | NaN | NaN | 4.875000e-01 | 0.53125 | 5.656250e-01 | ... | 0.036775 | 6.000000 | 0.569818 | 0.583594 | 0.572949 | 0.558824 | 0.576971 | 5.619524e-01 | 0.570685 | 0.008569 |
11 rows × 24 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
best_found_csp_components = [10, 10, 10]
best_found_solver = ["svd", "svd", "svd"]
best_found_tol = [0.0001, 0.0001, 0.0001]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    ################# TRAINING DATA #################
    with io.capture_output():
        # Get all training data (all but last session of participant)
        mne_raws = CLA_dataset.get_all_but_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for all those MNE raws (all training sessions)
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        y_train = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        X_train = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for training data
    del mne_raws
    del mne_raw
    del mne_epochs
    ################# TESTING DATA #################
    with io.capture_output():
        # Get test data
        mne_raw = CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Get epochs for test MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        y_test = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        X_test = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for testing data
    del mne_raw
    del mne_epochs
    ################# FIT AND PREDICT #################
    # Make the classifier
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    lda = LinearDiscriminantAnalysis(shrinkage= None,
                                     priors=[1/3, 1/3, 1/3],
                                     solver= best_found_solver[i],
                                     tol= best_found_tol[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('LDA', lda)])
    # Fit the pipeline
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_solver
del best_found_tol
del i
del X_test
del y_test
del X_train
del y_train
del csp
del lda
del pipeline
del y_pred
del accuracy
del start_offset
del end_offset
del baseline
del filter_lower_bound
del filter_upper_bound
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.4677
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.3587
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.5518
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
This experiment works as follows:
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        # Get all training data (all but last session of participant)
        mne_raws = CLA_dataset.get_all_but_last_raw_mne_data_for_subject(subject_id= subject_id)
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Delete all raws since concat changes them
        del mne_raws
        # Get epochs for all those MNE raws (all training sessions)
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        svm = SVC()
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('SVM', svm)])
        # Configure cross validation to use, more splits than before since we have more data
        cv = StratifiedKFold(n_splits= 6,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        param_grid = [{
            "CSP__n_components": [4, 6, 10],
            "SVM__C": [0.01, 0.1, 1, 10, 100],
            "SVM__kernel": ['rbf', 'sigmoid'],
            "SVM__gamma": ['scale', 'auto', 10, 1, 0.1, 0.01, 0.001]
        }, {
            "CSP__n_components": [4, 6, 10],
            "SVM__C": [0.01, 0.1, 1, 10, 100],
            "SVM__kernel": ['linear']
        }]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= mne_epochs_data, y= labels)
        # Store the results of the grid search
        with open(f"saved_variables/2/samesubject_differentsession/subject{subject_id}/gridsearch_cspsvm.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del svm
        del pipeline
        del labels
        del cv
        del file
        del grid_search
        del param_grid
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
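As a sanity check on the size of this search, the number of candidate configurations can be counted by hand. The sketch below copies the two sub-grids from the cell above and expands each one separately, which is how a list of dictionaries is treated in a grid search. The helper name `count_configurations` is made up for this illustration and is not part of the notebook or of scikit-learn.

```python
from math import prod

# Same two sub-grids as in the grid search cell, copied so this snippet runs on its own
param_grid = [
    {
        "CSP__n_components": [4, 6, 10],
        "SVM__C": [0.01, 0.1, 1, 10, 100],
        "SVM__kernel": ["rbf", "sigmoid"],
        "SVM__gamma": ["scale", "auto", 10, 1, 0.1, 0.01, 0.001],
    },
    {
        "CSP__n_components": [4, 6, 10],
        "SVM__C": [0.01, 0.1, 1, 10, 100],
        "SVM__kernel": ["linear"],
    },
]


def count_configurations(grid):
    """Hypothetical helper: sum the Cartesian product size of every sub-grid."""
    return sum(prod(len(values) for values in sub_grid.values()) for sub_grid in grid)


n_configurations = count_configurations(param_grid)  # 3*5*2*7 + 3*5*1
n_splits = 6                                         # folds used by StratifiedKFold above
n_fits = n_configurations * n_splits                 # pipeline fits per subject

print(f"{n_configurations} configurations, {n_fits} fits per subject")
```

The count of 225 agrees with the number of configurations reported in the grid search results for subject B.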
The CV results are based on the training set alone and thus only look at the first two sessions. The test result is for a new, unseen session and thus scores are expected to differ.
| Subject | CSP + SVM: cross validation accuracy | CSP + SVM: test split accuracy | Config |
|---|---|---|---|
| B | 0.4625 +- 0.0276 | 0.4677 | 10 CSP components, rbf SVM with C 10 and gamma 0.01 |
| C | 0.8338 +- 0.0213 | 0.3754 | 10 CSP components, rbf SVM with C 1 and gamma auto |
| E | 0.5816 +- 0.0255 | 0.3895 | 10 CSP components, rbf SVM with C 1 and gamma scale |
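To make the difference between cross validation and the unseen session concrete, the following minimal sketch recomputes the gap per subject from the numbers in the table above. The chance level of 1/3 is an assumption that holds only if the three classes (neutral, left, right) are roughly balanced in the test session; the variable names are illustrative.

```python
# Values copied from the table above: (CV accuracy, CV std, test accuracy)
results = {
    "B": (0.4625, 0.0276, 0.4677),
    "C": (0.8338, 0.0213, 0.3754),
    "E": (0.5816, 0.0255, 0.3895),
}

# ASSUMPTION: three balanced classes, so random guessing scores about 1/3
chance_level = 1 / 3

gaps = {}
for subject, (cv_acc, cv_std, test_acc) in results.items():
    # Positive gap: the pipeline did worse on the new session than in CV
    gaps[subject] = round(cv_acc - test_acc, 4)
    above_chance = test_acc - chance_level
    print(f"Subject {subject}: CV - test = {gaps[subject]:+.4f}, "
          f"test - chance = {above_chance:+.4f}")
```

Under these numbers, subject C shows by far the largest drop between cross validation and the new session, while subject B scores about the same in both.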
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/samesubject_differentsession/subject{subject_id}/gridsearch_cspsvm.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.4625 +- 0.0276 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 10, 'SVM__gamma': 0.01, 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 192 | 5.568229 | 0.063741 | 0.054359 | 0.004185 | 10 | 10 | 0.01 | rbf | 0.481250 | 0.462500 | ... | 0.027577 | 1 | 0.532541 | 0.514393 | 0.505632 | 0.526909 | 0.505316 | 0.495935 | 0.513454 | 0.012784 |
| 208 | 5.555005 | 0.064429 | 0.053316 | 0.001699 | 10 | 100 | 0.001 | rbf | 0.465625 | 0.453125 | ... | 0.027582 | 2 | 0.528160 | 0.509387 | 0.504380 | 0.513141 | 0.495935 | 0.490932 | 0.506989 | 0.012101 |
| 224 | 6.946046 | 0.085665 | 0.013663 | 0.001490 | 10 | 100 | NaN | linear | 0.456250 | 0.453125 | ... | 0.035536 | 3 | 0.531289 | 0.505006 | 0.496871 | 0.513767 | 0.490306 | 0.489681 | 0.504487 | 0.014623 |
| 223 | 5.712160 | 0.093627 | 0.013874 | 0.000405 | 10 | 10 | NaN | linear | 0.453125 | 0.456250 | ... | 0.035536 | 3 | 0.531289 | 0.507509 | 0.496871 | 0.514393 | 0.488430 | 0.489681 | 0.504696 | 0.015058 |
| 222 | 5.512335 | 0.057571 | 0.014495 | 0.000957 | 10 | 1 | NaN | linear | 0.456250 | 0.453125 | ... | 0.033869 | 5 | 0.534418 | 0.505632 | 0.492491 | 0.513767 | 0.490932 | 0.490932 | 0.504695 | 0.015782 |
| 221 | 5.527065 | 0.066557 | 0.014496 | 0.000957 | 10 | 0.1 | NaN | linear | 0.459375 | 0.456250 | ... | 0.033994 | 6 | 0.530038 | 0.506884 | 0.500000 | 0.511264 | 0.495310 | 0.490932 | 0.505738 | 0.012801 |
| 209 | 5.585397 | 0.078837 | 0.021660 | 0.001490 | 10 | 100 | 0.001 | sigmoid | 0.459375 | 0.456250 | ... | 0.033994 | 6 | 0.530663 | 0.506884 | 0.500000 | 0.510013 | 0.495935 | 0.490932 | 0.505738 | 0.012839 |
| 206 | 5.600729 | 0.075998 | 0.054316 | 0.003543 | 10 | 100 | 0.01 | rbf | 0.468750 | 0.425000 | ... | 0.027705 | 8 | 0.554443 | 0.533792 | 0.541302 | 0.540676 | 0.533458 | 0.522201 | 0.537645 | 0.009795 |
| 178 | 5.528270 | 0.085942 | 0.054982 | 0.002645 | 10 | 1 | 0.01 | rbf | 0.481250 | 0.456250 | ... | 0.029547 | 9 | 0.523154 | 0.505632 | 0.495620 | 0.506258 | 0.490932 | 0.494684 | 0.502713 | 0.010743 |
| 170 | 5.515809 | 0.097320 | 0.055316 | 0.001247 | 10 | 1 | auto | rbf | 0.456250 | 0.425000 | ... | 0.026014 | 10 | 0.566333 | 0.538798 | 0.545056 | 0.541302 | 0.534084 | 0.529706 | 0.542547 | 0.011720 |
10 rows × 25 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 21 | 5.556853 | 0.094353 | 0.023493 | 0.001707 | 4 | 0.1 | 1 | sigmoid | 0.296875 | 0.312500 | ... | 0.020845 | 216 | 0.287860 | 0.317272 | 0.282854 | 0.279725 | 0.285178 | 0.322076 | 0.295827 | 0.017095 |
| 159 | 5.541945 | 0.099997 | 0.019994 | 0.000577 | 10 | 0.1 | 10 | sigmoid | 0.303125 | 0.318750 | ... | 0.016170 | 217 | 0.265332 | 0.298498 | 0.280976 | 0.284105 | 0.305191 | 0.286429 | 0.286755 | 0.012762 |
| 103 | 5.508948 | 0.091336 | 0.016995 | 0.001633 | 6 | 1 | 10 | sigmoid | 0.259375 | 0.325000 | ... | 0.021060 | 218 | 0.289737 | 0.312265 | 0.279099 | 0.318523 | 0.292683 | 0.298937 | 0.298541 | 0.013402 |
| 47 | 5.486405 | 0.066430 | 0.016661 | 0.001598 | 4 | 10 | 10 | sigmoid | 0.268750 | 0.271875 | ... | 0.024305 | 219 | 0.287860 | 0.271589 | 0.280976 | 0.319775 | 0.276423 | 0.322076 | 0.293117 | 0.020275 |
| 91 | 5.567795 | 0.061089 | 0.025658 | 0.000942 | 6 | 0.1 | 1 | sigmoid | 0.296875 | 0.278125 | ... | 0.012310 | 220 | 0.288486 | 0.304130 | 0.282228 | 0.280350 | 0.278299 | 0.317699 | 0.291865 | 0.014369 |
| 161 | 5.620501 | 0.103255 | 0.029324 | 0.001374 | 10 | 0.1 | 1 | sigmoid | 0.300000 | 0.275000 | ... | 0.012355 | 221 | 0.290989 | 0.270964 | 0.277847 | 0.289111 | 0.303315 | 0.303315 | 0.289257 | 0.011994 |
| 61 | 5.504968 | 0.055060 | 0.016661 | 0.001105 | 4 | 100 | 10 | sigmoid | 0.268750 | 0.271875 | ... | 0.023402 | 222 | 0.287860 | 0.270964 | 0.280976 | 0.319149 | 0.276423 | 0.322702 | 0.293012 | 0.020402 |
| 33 | 5.499028 | 0.090676 | 0.016828 | 0.001213 | 4 | 1 | 10 | sigmoid | 0.265625 | 0.271875 | ... | 0.024798 | 223 | 0.290363 | 0.267209 | 0.282854 | 0.318523 | 0.275797 | 0.320200 | 0.292491 | 0.020251 |
| 19 | 5.530873 | 0.085733 | 0.019660 | 0.001490 | 4 | 0.1 | 10 | sigmoid | 0.268750 | 0.306250 | ... | 0.012291 | 224 | 0.278473 | 0.264706 | 0.286608 | 0.311014 | 0.272045 | 0.313321 | 0.287694 | 0.018530 |
| 89 | 5.538276 | 0.091236 | 0.018661 | 0.001598 | 6 | 0.1 | 10 | sigmoid | 0.262500 | 0.268750 | ... | 0.023116 | 225 | 0.284731 | 0.289111 | 0.275344 | 0.278473 | 0.293934 | 0.272045 | 0.282273 | 0.007701 |
10 rows × 25 columns
In total there are 225 different configurations tested.
The best mean test score is 0.4625
There are 1 configurations with this maximum score
There are 12 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.0 | 12.0 | 8.00 | 12 | 12.000000 | 12.000000 | ... | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 4.0 | 4.00 | 3 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 100.0 | 0.01 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 12.0 | 4.0 | 4.00 | 6 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.681793 | 0.076658 | 0.035399 | 0.001759 | NaN | NaN | NaN | NaN | 0.462240 | 0.448177 | ... | 0.030853 | 6.083333 | 0.538277 | 0.514758 | 0.510117 | 0.520338 | 0.503961 | 0.500573 | 0.514671 | 0.012779 |
| std | 0.401739 | 0.015399 | 0.020002 | 0.001184 | NaN | NaN | NaN | NaN | 0.010012 | 0.014193 | ... | 0.003774 | 3.315483 | 0.015063 | 0.013794 | 0.020643 | 0.013412 | 0.018542 | 0.016270 | 0.016083 | 0.001751 |
| min | 5.512335 | 0.053502 | 0.013663 | 0.000373 | NaN | NaN | NaN | NaN | 0.453125 | 0.425000 | ... | 0.026014 | 1.000000 | 0.523154 | 0.504380 | 0.492491 | 0.506258 | 0.488430 | 0.489681 | 0.502713 | 0.009795 |
| 25% | 5.527968 | 0.064257 | 0.014496 | 0.000957 | NaN | NaN | NaN | NaN | 0.456250 | 0.446094 | ... | 0.027581 | 3.000000 | 0.530507 | 0.505632 | 0.496871 | 0.511264 | 0.490932 | 0.490932 | 0.504696 | 0.011720 |
| 50% | 5.561617 | 0.077418 | 0.037488 | 0.001490 | NaN | NaN | NaN | NaN | 0.457812 | 0.454688 | ... | 0.031210 | 6.000000 | 0.531289 | 0.507196 | 0.500000 | 0.513767 | 0.495622 | 0.491245 | 0.505738 | 0.012792 |
| 75% | 5.589230 | 0.087863 | 0.054327 | 0.002247 | NaN | NaN | NaN | NaN | 0.466406 | 0.456250 | ... | 0.033994 | 9.250000 | 0.539424 | 0.519243 | 0.514549 | 0.530350 | 0.512351 | 0.502502 | 0.519502 | 0.013693 |
| max | 6.946046 | 0.097320 | 0.055316 | 0.004185 | NaN | NaN | NaN | NaN | 0.481250 | 0.462500 | ... | 0.035536 | 10.000000 | 0.566333 | 0.538798 | 0.545056 | 0.541302 | 0.534084 | 0.529706 | 0.542547 | 0.015782 |
11 rows × 25 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.8339 +- 0.0213 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 1, 'SVM__gamma': 'auto', 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 170 | 5.465572 | 0.048488 | 0.030657 | 0.000942 | 10 | 1 | auto | rbf | 0.837500 | 0.806250 | ... | 0.021303 | 1 | 0.855000 | 0.853125 | 0.851250 | 0.844375 | 0.864375 | 0.858125 | 0.854375 | 0.006134 |
| 176 | 5.503598 | 0.048057 | 0.031324 | 0.000942 | 10 | 1 | 0.1 | rbf | 0.837500 | 0.806250 | ... | 0.021303 | 1 | 0.855000 | 0.853125 | 0.851250 | 0.844375 | 0.864375 | 0.858125 | 0.854375 | 0.006134 |
| 168 | 5.448316 | 0.043673 | 0.029990 | 0.000816 | 10 | 1 | scale | rbf | 0.840625 | 0.800000 | ... | 0.018170 | 3 | 0.874375 | 0.870625 | 0.869375 | 0.857500 | 0.880625 | 0.868750 | 0.870208 | 0.006957 |
| 192 | 5.440268 | 0.065889 | 0.029324 | 0.001247 | 10 | 10 | 0.01 | rbf | 0.828125 | 0.793750 | ... | 0.021607 | 4 | 0.840625 | 0.843125 | 0.839375 | 0.831875 | 0.851250 | 0.846250 | 0.842083 | 0.006002 |
| 206 | 5.447359 | 0.053554 | 0.026159 | 0.000687 | 10 | 100 | 0.01 | rbf | 0.837500 | 0.800000 | ... | 0.017647 | 5 | 0.859375 | 0.855625 | 0.851875 | 0.843750 | 0.865625 | 0.853125 | 0.854896 | 0.006735 |
| 154 | 5.475635 | 0.060550 | 0.043653 | 0.001374 | 10 | 0.1 | scale | rbf | 0.834375 | 0.803125 | ... | 0.024324 | 6 | 0.835000 | 0.851250 | 0.843750 | 0.823125 | 0.848750 | 0.845000 | 0.841146 | 0.009523 |
| 184 | 5.450008 | 0.049254 | 0.027158 | 0.002733 | 10 | 10 | auto | rbf | 0.840625 | 0.784375 | ... | 0.024766 | 7 | 0.890625 | 0.886250 | 0.871250 | 0.870000 | 0.887500 | 0.880000 | 0.880937 | 0.007953 |
| 190 | 5.449612 | 0.059454 | 0.028657 | 0.004817 | 10 | 10 | 0.1 | rbf | 0.840625 | 0.784375 | ... | 0.024766 | 7 | 0.890625 | 0.886250 | 0.871250 | 0.870000 | 0.887500 | 0.880000 | 0.880937 | 0.007953 |
| 221 | 5.440358 | 0.062691 | 0.010497 | 0.000764 | 10 | 0.1 | NaN | linear | 0.825000 | 0.781250 | ... | 0.021092 | 9 | 0.829375 | 0.838125 | 0.834375 | 0.823750 | 0.843750 | 0.841250 | 0.835104 | 0.006879 |
| 223 | 5.502102 | 0.066761 | 0.009330 | 0.000745 | 10 | 10 | NaN | linear | 0.815625 | 0.781250 | ... | 0.021399 | 9 | 0.825625 | 0.839375 | 0.835625 | 0.825625 | 0.851250 | 0.846875 | 0.837396 | 0.009712 |
10 rows × 25 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 10 | 5.501662 | 0.042406 | 0.051650 | 0.000745 | 4 | 0.01 | 0.01 | rbf | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 12 | 5.498958 | 0.082246 | 0.053861 | 0.002197 | 4 | 0.01 | 0.001 | rbf | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 82 | 5.487833 | 0.079256 | 0.053316 | 0.000745 | 6 | 0.01 | 0.001 | rbf | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 13 | 5.448689 | 0.054929 | 0.018161 | 0.000373 | 4 | 0.01 | 0.001 | sigmoid | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 96 | 5.534901 | 0.058348 | 0.053150 | 0.000373 | 6 | 0.1 | 0.001 | rbf | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 81 | 5.487553 | 0.062595 | 0.019161 | 0.000372 | 6 | 0.01 | 0.01 | sigmoid | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 97 | 5.510097 | 0.058802 | 0.018994 | 0.000578 | 6 | 0.1 | 0.001 | sigmoid | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 27 | 5.477552 | 0.058636 | 0.018827 | 0.002409 | 4 | 0.1 | 0.001 | sigmoid | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 11 | 5.490853 | 0.042248 | 0.020327 | 0.004345 | 4 | 0.01 | 0.01 | sigmoid | 0.337500 | 0.33750 | ... | 5.551115e-17 | 202 | 0.3375 | 0.337500 | 0.337500 | 0.337500 | 0.3375 | 0.337500 | 0.337500 | 5.551115e-17 |
| 19 | 5.521237 | 0.092961 | 0.020827 | 0.002671 | 4 | 0.1 | 10 | sigmoid | 0.384375 | 0.46875 | ... | 8.082823e-02 | 225 | 0.3650 | 0.399375 | 0.265625 | 0.215625 | 0.2750 | 0.373125 | 0.315625 | 6.696839e-02 |
10 rows × 25 columns
In total there are 225 different configurations tested.
The best mean test score is 0.8339
There are 2 configurations with this maximum score
There are 26 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 26.000000 | 26.000000 | 26.000000 | 2.600000e+01 | 26.0 | 26.0 | 21.00 | 26 | 26.000000 | 26.000000 | ... | 26.000000 | 26.000000 | 26.000000 | 26.000000 | 26.000000 | 26.000000 | 26.000000 | 26.000000 | 26.000000 | 26.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 5.0 | 6.00 | 3 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | 0.01 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 26.0 | 8.0 | 6.00 | 16 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.477356 | 0.059524 | 0.027154 | 1.673842e-03 | NaN | NaN | NaN | NaN | 0.822476 | 0.791106 | ... | 0.019908 | 12.692308 | 0.847716 | 0.852187 | 0.848413 | 0.838173 | 0.859880 | 0.855096 | 0.850244 | 0.007584 |
| std | 0.069465 | 0.010608 | 0.013542 | 1.671689e-03 | NaN | NaN | NaN | NaN | 0.011892 | 0.010372 | ... | 0.003035 | 7.556556 | 0.031903 | 0.026965 | 0.024916 | 0.029282 | 0.024856 | 0.025193 | 0.027044 | 0.001589 |
| min | 5.431840 | 0.043673 | 0.008540 | 7.240322e-07 | NaN | NaN | NaN | NaN | 0.796875 | 0.775000 | ... | 0.014254 | 1.000000 | 0.823125 | 0.830625 | 0.829375 | 0.812500 | 0.839375 | 0.835625 | 0.828750 | 0.004152 |
| 25% | 5.447602 | 0.051378 | 0.013246 | 7.449866e-04 | NaN | NaN | NaN | NaN | 0.815625 | 0.781250 | ... | 0.017514 | 7.000000 | 0.827500 | 0.836406 | 0.834375 | 0.822188 | 0.843750 | 0.840313 | 0.834193 | 0.006714 |
| 50% | 5.455959 | 0.058473 | 0.028598 | 9.711213e-04 | NaN | NaN | NaN | NaN | 0.821875 | 0.793750 | ... | 0.020578 | 11.500000 | 0.830000 | 0.839063 | 0.837813 | 0.825313 | 0.849375 | 0.845625 | 0.837708 | 0.007953 |
| 75% | 5.484652 | 0.062394 | 0.039154 | 1.673418e-03 | NaN | NaN | NaN | NaN | 0.832812 | 0.799219 | ... | 0.021998 | 19.000000 | 0.858281 | 0.855000 | 0.851719 | 0.844375 | 0.865312 | 0.858125 | 0.854766 | 0.008785 |
| max | 5.797031 | 0.086795 | 0.051691 | 6.469727e-03 | NaN | NaN | NaN | NaN | 0.840625 | 0.812500 | ... | 0.024766 | 26.000000 | 0.943750 | 0.933750 | 0.933125 | 0.933125 | 0.936875 | 0.941250 | 0.936979 | 0.010251 |
11 rows × 25 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.5816 +- 0.0255 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 1, 'SVM__gamma': 'scale', 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 168 | 5.482249 | 0.029599 | 0.050150 | 0.001950 | 10 | 1 | scale | rbf | 0.568750 | 0.546875 | ... | 0.025547 | 1 | 0.663745 | 0.673763 | 0.653726 | 0.660826 | 0.670839 | 0.657697 | 0.663433 | 0.007019 |
| 206 | 5.616039 | 0.059804 | 0.048818 | 0.002266 | 10 | 100 | 0.01 | rbf | 0.578125 | 0.550000 | ... | 0.035902 | 2 | 0.632436 | 0.647464 | 0.619912 | 0.638298 | 0.634543 | 0.634543 | 0.634533 | 0.008160 |
| 176 | 5.504242 | 0.047538 | 0.049484 | 0.001117 | 10 | 1 | 0.1 | rbf | 0.559375 | 0.537500 | ... | 0.029201 | 3 | 0.639950 | 0.652473 | 0.633688 | 0.642053 | 0.642678 | 0.636421 | 0.641210 | 0.005928 |
| 170 | 5.488913 | 0.047904 | 0.049317 | 0.002134 | 10 | 1 | auto | rbf | 0.559375 | 0.537500 | ... | 0.029201 | 3 | 0.639950 | 0.652473 | 0.633688 | 0.642053 | 0.642678 | 0.636421 | 0.641210 | 0.005928 |
| 154 | 5.508740 | 0.065091 | 0.055316 | 0.002494 | 10 | 0.1 | scale | rbf | 0.531250 | 0.562500 | ... | 0.022409 | 5 | 0.602379 | 0.626800 | 0.605510 | 0.615144 | 0.615144 | 0.603254 | 0.611372 | 0.008637 |
| 192 | 5.513239 | 0.031337 | 0.049651 | 0.001105 | 10 | 10 | 0.01 | rbf | 0.553125 | 0.546875 | ... | 0.029420 | 6 | 0.611772 | 0.618660 | 0.607389 | 0.607009 | 0.613892 | 0.599499 | 0.609704 | 0.006039 |
| 162 | 5.516238 | 0.034486 | 0.056316 | 0.002493 | 10 | 0.1 | 0.1 | rbf | 0.509375 | 0.568750 | ... | 0.026863 | 7 | 0.577959 | 0.614277 | 0.601127 | 0.589487 | 0.595119 | 0.586984 | 0.594159 | 0.011474 |
| 156 | 5.518404 | 0.058280 | 0.058315 | 0.002285 | 10 | 0.1 | auto | rbf | 0.509375 | 0.568750 | ... | 0.026863 | 7 | 0.577959 | 0.614277 | 0.601127 | 0.589487 | 0.595119 | 0.586984 | 0.594159 | 0.011474 |
| 190 | 5.503409 | 0.045187 | 0.046318 | 0.001795 | 10 | 10 | 0.1 | rbf | 0.562500 | 0.562500 | ... | 0.013333 | 9 | 0.697558 | 0.701315 | 0.700063 | 0.695244 | 0.693992 | 0.681477 | 0.694941 | 0.006532 |
| 184 | 5.517238 | 0.039345 | 0.046985 | 0.001000 | 10 | 10 | auto | rbf | 0.562500 | 0.562500 | ... | 0.013333 | 9 | 0.697558 | 0.701315 | 0.700063 | 0.695244 | 0.693992 | 0.681477 | 0.694941 | 0.006532 |
10 rows × 25 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 47 | 5.462255 | 0.066934 | 0.018328 | 0.000943 | 4 | 10 | 10 | sigmoid | 0.309375 | 0.281250 | ... | 0.066304 | 215 | 0.278647 | 0.288666 | 0.393863 | 0.283479 | 0.285357 | 0.296621 | 0.304439 | 0.040364 |
| 33 | 6.413369 | 0.119142 | 0.019494 | 0.001708 | 4 | 1 | 10 | sigmoid | 0.309375 | 0.278125 | ... | 0.067198 | 217 | 0.278647 | 0.286788 | 0.395116 | 0.282228 | 0.284731 | 0.296621 | 0.304022 | 0.041111 |
| 29 | 5.460122 | 0.053549 | 0.016828 | 0.001950 | 4 | 1 | scale | sigmoid | 0.262500 | 0.306250 | ... | 0.036929 | 218 | 0.296807 | 0.300564 | 0.417032 | 0.287234 | 0.299124 | 0.304130 | 0.317482 | 0.044822 |
| 19 | 5.495521 | 0.033737 | 0.018661 | 0.000745 | 4 | 0.1 | 10 | sigmoid | 0.309375 | 0.281250 | ... | 0.052732 | 219 | 0.279274 | 0.292423 | 0.385097 | 0.280976 | 0.285982 | 0.294118 | 0.302978 | 0.037122 |
| 57 | 5.498077 | 0.058346 | 0.017828 | 0.002671 | 4 | 100 | scale | sigmoid | 0.253125 | 0.306250 | ... | 0.043852 | 220 | 0.302442 | 0.303694 | 0.420789 | 0.288486 | 0.295369 | 0.302253 | 0.318839 | 0.045897 |
| 21 | 5.525791 | 0.046490 | 0.021494 | 0.001708 | 4 | 0.1 | 1 | sigmoid | 0.293750 | 0.293750 | ... | 0.046362 | 221 | 0.277395 | 0.284283 | 0.377583 | 0.285982 | 0.285357 | 0.292866 | 0.300578 | 0.034730 |
| 43 | 5.450592 | 0.040709 | 0.015662 | 0.000745 | 4 | 10 | scale | sigmoid | 0.253125 | 0.300000 | ... | 0.043362 | 222 | 0.298059 | 0.302442 | 0.421415 | 0.287860 | 0.309136 | 0.299750 | 0.319777 | 0.045890 |
| 49 | 5.556558 | 0.046089 | 0.020827 | 0.001343 | 4 | 10 | 1 | sigmoid | 0.287500 | 0.290625 | ... | 0.041935 | 223 | 0.273012 | 0.286162 | 0.401378 | 0.285982 | 0.287860 | 0.291615 | 0.304335 | 0.043777 |
| 63 | 5.492245 | 0.055231 | 0.020494 | 0.001384 | 4 | 100 | 1 | sigmoid | 0.287500 | 0.287500 | ... | 0.043702 | 224 | 0.273638 | 0.293676 | 0.400751 | 0.284731 | 0.288486 | 0.291615 | 0.305483 | 0.043092 |
| 35 | 5.905614 | 0.185749 | 0.021659 | 0.001885 | 4 | 1 | 1 | sigmoid | 0.287500 | 0.290625 | ... | 0.044103 | 225 | 0.273012 | 0.287414 | 0.400751 | 0.286608 | 0.284105 | 0.294118 | 0.304335 | 0.043573 |
10 rows × 25 columns
In total there are 225 different configurations tested.
The best mean test score is 0.5816
There are 1 configurations with this maximum score
There are 10 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.0 | 10.0 | 10.0 | 10 | 10.000000 | 10.000000 | ... | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 4.0 | 4.0 | 1 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 1.0 | 0.1 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 10.0 | 3.0 | 3.0 | 10 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.516871 | 0.045857 | 0.051067 | 0.001864 | NaN | NaN | NaN | NaN | 0.549375 | 0.554375 | ... | 0.025207 | 5.200000 | 0.634126 | 0.650282 | 0.635629 | 0.637484 | 0.639800 | 0.630476 | 0.637966 | 0.007772 |
| std | 0.036875 | 0.012352 | 0.004094 | 0.000587 | NaN | NaN | NaN | NaN | 0.024242 | 0.012076 | ... | 0.007145 | 2.859681 | 0.043221 | 0.033203 | 0.037963 | 0.038435 | 0.036747 | 0.035768 | 0.037295 | 0.002154 |
| min | 5.482249 | 0.029599 | 0.046318 | 0.001000 | NaN | NaN | NaN | NaN | 0.509375 | 0.537500 | ... | 0.013333 | 1.000000 | 0.577959 | 0.614277 | 0.601127 | 0.589487 | 0.595119 | 0.586984 | 0.594159 | 0.005928 |
| 25% | 5.503617 | 0.035700 | 0.048942 | 0.001287 | NaN | NaN | NaN | NaN | 0.536719 | 0.546875 | ... | 0.023194 | 3.000000 | 0.604728 | 0.620695 | 0.605980 | 0.609043 | 0.614205 | 0.600438 | 0.610121 | 0.006162 |
| 50% | 5.510989 | 0.046362 | 0.049568 | 0.002042 | NaN | NaN | NaN | NaN | 0.559375 | 0.556250 | ... | 0.026863 | 5.500000 | 0.636193 | 0.649969 | 0.626800 | 0.640175 | 0.638611 | 0.635482 | 0.637872 | 0.006775 |
| 75% | 5.516988 | 0.055686 | 0.054025 | 0.002280 | NaN | NaN | NaN | NaN | 0.562500 | 0.562500 | ... | 0.029201 | 7.000000 | 0.657796 | 0.668441 | 0.648716 | 0.656133 | 0.663798 | 0.652378 | 0.657877 | 0.008518 |
| max | 5.616039 | 0.065091 | 0.058315 | 0.002494 | NaN | NaN | NaN | NaN | 0.578125 | 0.568750 | ... | 0.035902 | 9.000000 | 0.697558 | 0.701315 | 0.700063 | 0.695244 | 0.693992 | 0.681477 | 0.694941 | 0.011474 |
11 rows × 25 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
best_found_csp_components = [10, 10, 10]
# NOTE: the grid search output for subject E above reports "rbf" as the best kernel, not "sigmoid"
best_found_svm_kernel = ["rbf", "rbf", "sigmoid"]
best_found_svm_c = [10, 1, 1]
best_found_svm_gamma = [0.01, "auto", "scale"]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    ################# TRAINING DATA #################
    with io.capture_output():
        # Get all training data (all but last session of participant)
        mne_raws = CLA_dataset.get_all_but_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for all those MNE raws (all training sessions)
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
    # Get the labels
    y_train = mne_epochs.events[:, -1]
    # Use a fixed filter
    mne_epochs.filter(l_freq= filter_lower_bound,
                      h_freq= filter_upper_bound,
                      picks= "all",
                      phase= "minimum",
                      fir_window= "blackman",
                      fir_design= "firwin",
                      pad= 'median',
                      n_jobs= -1,
                      verbose= False)
    # Get a half second window
    X_train = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for training data
    del mne_raws
    del mne_raw
    del mne_epochs
    ################# TESTING DATA #################
    with io.capture_output():
        # Get test data
        mne_raw = CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Get epochs for test MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
    # Get the labels
    y_test = mne_epochs.events[:, -1]
    # Use a fixed filter
    mne_epochs.filter(l_freq= filter_lower_bound,
                      h_freq= filter_upper_bound,
                      picks= "all",
                      phase= "minimum",
                      fir_window= "blackman",
                      fir_design= "firwin",
                      pad= 'median',
                      n_jobs= -1,
                      verbose= False)
    # Get a half second window
    X_test = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for testing data
    del mne_raw
    del mne_epochs
    ################# FIT AND PREDICT #################
    # Make the classifier
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    svm = SVC(kernel= best_found_svm_kernel[i],
              C= best_found_svm_c[i],
              gamma= best_found_svm_gamma[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('SVM', svm)])
    # Fit the pipeline
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_svm_kernel
del best_found_svm_c
del best_found_svm_gamma
del i
del X_test
del y_test
del X_train
del y_train
del csp
del svm
del pipeline
del y_pred
del accuracy
del start_offset
del end_offset
del baseline
del filter_lower_bound
del filter_upper_bound
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.4677
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.3754
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.3895
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
This experiment works as follows:
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        # Get all training data (all but last session of participant)
        mne_raws = CLA_dataset.get_all_but_last_raw_mne_data_for_subject(subject_id= subject_id)
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Delete all raws since concat changes them
        del mne_raws
        # Get epochs for all those MNE raws (all training sessions)
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        rf = RandomForestClassifier(bootstrap= True,
                                    criterion= "gini")
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('RF', rf)])
        # Configure cross validation to use, more splits than before since we have more data
        cv = StratifiedKFold(n_splits= 6,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        # max_features= None (the Python object, not the string "None") uses all features
        param_grid = [{"CSP__n_components": [4, 6, 10],
                       "RF__n_estimators": [10, 50, 100, 250, 500],
                       "RF__max_depth": [None, 3, 10],
                       "RF__min_samples_split": [2, 5, 10],
                       "RF__max_features": ["sqrt", "log2", None, 0.2, 0.4, 0.6]}]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= mne_epochs_data, y= labels)
        # Store the results of the grid search
        with open(f"saved_variables/2/samesubject_differentsession/subject{subject_id}/gridsearch_csprf.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del rf
        del pipeline
        del labels
        del cv
        del file
        del grid_search
        del param_grid
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
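The note on limited computational resources can be made concrete by counting how many pipeline fits the random forest grid implies. The snippet below is a minimal sketch that mirrors the hyperparameter values and the six-fold cross validation of the cell above; the variable names `n_configurations` and `n_fits` are introduced here for illustration only and do not appear in the original notebook.

```python
from math import prod

# Hyperparameter values mirroring the param_grid of the grid search cell
param_grid = {
    "CSP__n_components": [4, 6, 10],
    "RF__n_estimators": [10, 50, 100, 250, 500],
    "RF__max_depth": [None, 3, 10],
    "RF__min_samples_split": [2, 5, 10],
    "RF__max_features": ["sqrt", "log2", None, 0.2, 0.4, 0.6],
}
n_splits = 6  # Same number of folds as the StratifiedKFold configuration

# A grid search tries every combination, so the sizes multiply
n_configurations = prod(len(values) for values in param_grid.values())
# Every configuration is fitted once per cross validation fold
n_fits = n_configurations * n_splits

print(f"{n_configurations} configurations, {n_fits} pipeline fits per subject")
```

This gives 810 configurations and 4860 fits per subject, which is why the experiment is disabled by default through `do_experiment`.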
The CV results are based on the training set alone and thus only look at the first two sessions. The test result is for a new, unseen session and thus scores are expected to differ.
| Subject | CSP + RF: cross validation accuracy | CSP + RF: test split accuracy | Config |
|---|---|---|---|
| B | 0.4489 +- 0.0351 | 0.4406 | 10 CSP components; RF with max depth 3, 0.4 features, 10 min sample split, 500 estimators |
| C | 0.8198 +- 0.0198 | 0.3462 | 10 CSP components; RF with max depth None, 0.2 features, 2 min sample split, 50 estimators |
| E | 0.5770 +- 0.0290 | 0.4911 | 10 CSP components; RF with max depth 10, 0.4 features, 10 min sample split, 250 estimators |
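To quantify how much the scores differ between cross validation and the unseen session, the following minimal sketch recomputes the gap from the accuracies in the table above. The dictionary layout and the `chance_level` of 1/3, which assumes three roughly balanced classes, are assumptions made for illustration and are not part of the original notebook.

```python
# Accuracies copied from the CSP + RF table: (cross validation, test split)
results = {
    "B": (0.4489, 0.4406),
    "C": (0.8198, 0.3462),
    "E": (0.5770, 0.4911),
}

# Assumption: three roughly balanced classes (neutral, left, right)
chance_level = 1 / 3

# Drop from cross validation accuracy to accuracy on the unseen session
gaps = {subject: round(cv - test, 4) for subject, (cv, test) in results.items()}

for subject, (cv, test) in results.items():
    above_chance = test - chance_level
    print(f"Subject {subject}: CV {cv:.4f}, test {test:.4f}, "
          f"gap {gaps[subject]:.4f}, test above chance by {above_chance:.4f}")
```

Under these assumptions subject C shows by far the largest drop, with a test accuracy only slightly above the assumed chance level, while subject B barely changes between the two settings.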
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/samesubject_differentsession/subject{subject_id}/gridsearch_csprf.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.4489 +- 0.0351 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': 3, 'RF__max_features': 0.4, 'RF__min_samples_split': 10, 'RF__n_estimators': 500}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 704 | 6.288822 | 0.033622 | 0.055982 | 0.004508 | 10 | 3 | 0.4 | 10 | 500 | 0.475000 | ... | 0.035137 | 1 | 0.540050 | 0.523154 | 0.516270 | 0.534418 | 0.519074 | 0.512195 | 0.524194 | 0.009916 |
| 647 | 5.473083 | 0.066109 | 0.014329 | 0.000471 | 10 | 3 | log2 | 2 | 100 | 0.478125 | ... | 0.035396 | 2 | 0.538798 | 0.543179 | 0.517522 | 0.546934 | 0.527205 | 0.522827 | 0.532744 | 0.010860 |
| 698 | 5.724003 | 0.024250 | 0.028991 | 0.000577 | 10 | 3 | 0.4 | 5 | 250 | 0.459375 | ... | 0.031401 | 3 | 0.537547 | 0.520025 | 0.511890 | 0.537547 | 0.511570 | 0.519700 | 0.523046 | 0.010778 |
| 699 | 6.274327 | 0.054617 | 0.053983 | 0.000577 | 10 | 3 | 0.4 | 5 | 500 | 0.481250 | ... | 0.030516 | 4 | 0.540050 | 0.516270 | 0.513767 | 0.534418 | 0.519074 | 0.517824 | 0.523567 | 0.009932 |
| 649 | 6.198851 | 0.051525 | 0.055316 | 0.001105 | 10 | 3 | log2 | 2 | 500 | 0.478125 | ... | 0.034705 | 5 | 0.550063 | 0.533792 | 0.513141 | 0.544431 | 0.529706 | 0.525954 | 0.532848 | 0.012091 |
| 690 | 5.215166 | 0.029022 | 0.005165 | 0.000372 | 10 | 3 | 0.4 | 2 | 10 | 0.481250 | ... | 0.033004 | 6 | 0.521902 | 0.506884 | 0.505632 | 0.518773 | 0.503440 | 0.507192 | 0.510637 | 0.007023 |
| 648 | 5.705676 | 0.045883 | 0.031156 | 0.003974 | 10 | 3 | log2 | 2 | 250 | 0.471875 | ... | 0.032963 | 7 | 0.549437 | 0.535044 | 0.513767 | 0.532541 | 0.529081 | 0.523452 | 0.530554 | 0.010917 |
| 659 | 6.189020 | 0.022069 | 0.054316 | 0.000471 | 10 | 3 | log2 | 10 | 500 | 0.481250 | ... | 0.036880 | 8 | 0.543179 | 0.537547 | 0.511264 | 0.540050 | 0.534084 | 0.520325 | 0.531075 | 0.011446 |
| 693 | 5.727002 | 0.038843 | 0.029157 | 0.000687 | 10 | 3 | 0.4 | 2 | 250 | 0.471875 | ... | 0.035690 | 9 | 0.537547 | 0.526909 | 0.516896 | 0.533792 | 0.518449 | 0.514697 | 0.524715 | 0.008686 |
| 702 | 5.392275 | 0.025105 | 0.014329 | 0.000471 | 10 | 3 | 0.4 | 10 | 100 | 0.456250 | ... | 0.033377 | 10 | 0.541927 | 0.525031 | 0.516270 | 0.533166 | 0.527205 | 0.510944 | 0.525757 | 0.010234 |
10 rows × 26 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 489 | 5.310468 | 0.040122 | 0.0 | 0.0 | 6 | 10 | None | 5 | 500 | NaN | ... | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 481 | 5.201670 | 0.050766 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 5.258152 | 0.033508 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 5.203669 | 0.048178 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 5.189840 | 0.037868 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 5.218998 | 0.047122 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 5.432763 | 0.083974 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 398 | 5.278645 | 0.034963 | 0.0 | 0.0 | 6 | 3 | None | 5 | 250 | NaN | ... | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 5.166847 | 0.039080 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 5.323298 | 0.029294 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 26 columns
In total there are 810 different configurations tested.
The best mean test score is 0.4489
There are 1 configurations with this maximum score
There are 146 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 146.000000 | 146.000000 | 146.000000 | 1.460000e+02 | 146.0 | 137.0 | 146 | 146.0 | 146.0 | 146.000000 | ... | 146.000000 | 146.000000 | 146.000000 | 146.000000 | 146.000000 | 146.000000 | 146.000000 | 146.000000 | 146.000000 | 146.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 2.0 | 5 | 3.0 | 5.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 3.0 | sqrt | 10.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 99.0 | 110.0 | 33 | 54.0 | 40.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 5.724071 | 0.043302 | 0.029333 | 1.110624e-03 | NaN | NaN | NaN | NaN | NaN | 0.457663 | ... | 0.028474 | 73.452055 | 0.629155 | 0.628470 | 0.619550 | 0.623391 | 0.623022 | 0.622979 | 0.624428 | 0.008888 |
| std | 0.474131 | 0.011790 | 0.020131 | 1.325608e-03 | NaN | NaN | NaN | NaN | NaN | 0.012981 | ... | 0.006108 | 42.313621 | 0.182232 | 0.185578 | 0.193853 | 0.193275 | 0.191195 | 0.191815 | 0.189502 | 0.003045 |
| min | 5.184343 | 0.018081 | 0.004999 | 1.777067e-07 | NaN | NaN | NaN | NaN | NaN | 0.428125 | ... | 0.014030 | 1.000000 | 0.492491 | 0.493742 | 0.493742 | 0.472466 | 0.492183 | 0.490932 | 0.494264 | 0.000000 |
| 25% | 5.345166 | 0.034625 | 0.013204 | 4.710965e-04 | NaN | NaN | NaN | NaN | NaN | 0.450000 | ... | 0.024821 | 37.250000 | 0.512672 | 0.517522 | 0.505632 | 0.492491 | 0.507817 | 0.507817 | 0.506882 | 0.007276 |
| 50% | 5.605790 | 0.042823 | 0.028574 | 5.771021e-04 | NaN | NaN | NaN | NaN | NaN | 0.456250 | ... | 0.028377 | 73.500000 | 0.538486 | 0.527534 | 0.513454 | 0.531602 | 0.520325 | 0.517824 | 0.524662 | 0.009198 |
| 75% | 6.027031 | 0.052105 | 0.053608 | 1.105207e-03 | NaN | NaN | NaN | NaN | NaN | 0.465625 | ... | 0.033295 | 109.750000 | 0.561170 | 0.548811 | 0.533479 | 0.551783 | 0.544246 | 0.542839 | 0.541997 | 0.010892 |
| max | 7.354648 | 0.077664 | 0.071810 | 8.231705e-03 | NaN | NaN | NaN | NaN | NaN | 0.481250 | ... | 0.040330 | 146.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.016772 |
11 rows × 26 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.8198 +- 0.0198 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': None, 'RF__max_features': 0.2, 'RF__min_samples_split': 2, 'RF__n_estimators': 50}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 586 | 5.468418 | 0.071404 | 0.014496 | 0.006549 | 10 | None | 0.2 | 2 | 50 | 0.818750 | ... | 0.019819 | 1 | 0.999375 | 1.000000 | 1.000000 | 0.999375 | 1.000000 | 0.999375 | 0.999687 | 0.000312 |
| 558 | 6.179690 | 0.057260 | 0.036989 | 0.000817 | 10 | None | log2 | 2 | 250 | 0.793750 | ... | 0.026125 | 2 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 587 | 5.598710 | 0.043140 | 0.017828 | 0.001067 | 10 | None | 0.2 | 2 | 100 | 0.806250 | ... | 0.018641 | 3 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 768 | 5.996082 | 0.053971 | 0.037821 | 0.004596 | 10 | 10 | 0.2 | 2 | 250 | 0.803125 | ... | 0.019348 | 4 | 0.976250 | 0.981250 | 0.983750 | 0.981250 | 0.980625 | 0.983125 | 0.981042 | 0.002412 |
| 592 | 5.609207 | 0.054281 | 0.018494 | 0.000500 | 10 | None | 0.2 | 5 | 100 | 0.809375 | ... | 0.019094 | 4 | 0.998125 | 0.995625 | 0.995000 | 0.996875 | 0.995000 | 0.994375 | 0.995833 | 0.001284 |
| 615 | 5.382945 | 0.069358 | 0.006498 | 0.000499 | 10 | None | 0.6 | 2 | 10 | 0.812500 | ... | 0.025430 | 6 | 0.985625 | 0.988125 | 0.988750 | 0.991250 | 0.985000 | 0.991250 | 0.988333 | 0.002438 |
| 567 | 5.641696 | 0.029298 | 0.017495 | 0.000957 | 10 | None | log2 | 10 | 100 | 0.803125 | ... | 0.024122 | 7 | 0.960625 | 0.957500 | 0.962500 | 0.957500 | 0.960000 | 0.959375 | 0.959583 | 0.001755 |
| 774 | 6.660537 | 0.057995 | 0.066645 | 0.001598 | 10 | 10 | 0.2 | 5 | 500 | 0.800000 | ... | 0.019073 | 7 | 0.970000 | 0.964375 | 0.971875 | 0.966875 | 0.970000 | 0.965625 | 0.968125 | 0.002676 |
| 781 | 5.485912 | 0.052559 | 0.011497 | 0.001258 | 10 | 10 | 0.4 | 2 | 50 | 0.796875 | ... | 0.024788 | 9 | 0.979375 | 0.981250 | 0.983125 | 0.985000 | 0.983750 | 0.980000 | 0.982083 | 0.002031 |
| 738 | 6.136705 | 0.030369 | 0.036155 | 0.002339 | 10 | 10 | log2 | 2 | 250 | 0.796875 | ... | 0.022146 | 10 | 0.982500 | 0.985625 | 0.986875 | 0.983125 | 0.985000 | 0.983750 | 0.984479 | 0.001506 |
10 rows × 26 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 489 | 5.441760 | 0.072491 | 0.0 | 0.0 | 6 | 10 | None | 5 | 500 | NaN | ... | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 481 | 5.347956 | 0.050206 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 5.352288 | 0.050912 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 5.313301 | 0.061456 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 5.283144 | 0.054436 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 5.299139 | 0.065402 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 5.402773 | 0.034252 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 398 | 5.354954 | 0.058481 | 0.0 | 0.0 | 6 | 3 | None | 5 | 250 | NaN | ... | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 5.284477 | 0.053404 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 5.430597 | 0.048202 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 26 columns
In total there are 810 different configurations tested.
The best mean test score is 0.8198
There are 1 configurations with this maximum score
There are 129 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 129.000000 | 129.000000 | 129.000000 | 1.290000e+02 | 129.0 | 66.0 | 129.0 | 129.0 | 129.0 | 129.000000 | ... | 129.000000 | 129.000000 | 129.000000 | 129.000000 | 129.000000 | 129.000000 | 129.000000 | 129.000000 | 129.000000 | 129.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 1.0 | 5.0 | 3.0 | 5.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | 0.6 | 10.0 | 250.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 128.0 | 66.0 | 27.0 | 44.0 | 31.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 6.113911 | 0.055613 | 0.031845 | 1.764666e-03 | NaN | NaN | NaN | NaN | NaN | 0.796342 | ... | 0.020744 | 63.620155 | 0.972611 | 0.971773 | 0.975426 | 0.973609 | 0.973324 | 0.973372 | 0.973353 | 0.002294 |
| std | 0.722545 | 0.021618 | 0.022401 | 1.986289e-03 | NaN | NaN | NaN | NaN | NaN | 0.007351 | ... | 0.003154 | 37.547926 | 0.021050 | 0.023346 | 0.020466 | 0.021426 | 0.020719 | 0.021207 | 0.021251 | 0.001366 |
| min | 5.327130 | 0.024757 | 0.005665 | 4.298152e-07 | NaN | NaN | NaN | NaN | NaN | 0.771875 | ... | 0.011219 | 1.000000 | 0.921875 | 0.921875 | 0.923125 | 0.924375 | 0.927500 | 0.928750 | 0.925312 | 0.000000 |
| 25% | 5.567220 | 0.044128 | 0.011497 | 6.870479e-04 | NaN | NaN | NaN | NaN | NaN | 0.790625 | ... | 0.018786 | 31.000000 | 0.961250 | 0.956250 | 0.962500 | 0.960000 | 0.960000 | 0.958750 | 0.960104 | 0.001284 |
| 50% | 5.866124 | 0.052559 | 0.018494 | 9.996495e-04 | NaN | NaN | NaN | NaN | NaN | 0.796875 | ... | 0.020963 | 60.000000 | 0.970000 | 0.972500 | 0.977500 | 0.975000 | 0.972500 | 0.973125 | 0.973542 | 0.002412 |
| 75% | 6.632546 | 0.063156 | 0.039488 | 1.885214e-03 | NaN | NaN | NaN | NaN | NaN | 0.800000 | ... | 0.022535 | 94.000000 | 0.995000 | 0.995625 | 0.995000 | 0.995000 | 0.994375 | 0.993750 | 0.994896 | 0.003256 |
| max | 8.312841 | 0.239322 | 0.076809 | 1.085416e-02 | NaN | NaN | NaN | NaN | NaN | 0.818750 | ... | 0.028565 | 128.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.005349 |
11 rows × 26 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.577 +- 0.029 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': 10, 'RF__max_features': 0.4, 'RF__min_samples_split': 10, 'RF__n_estimators': 250}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 793 | 6.246836 | 0.058861 | 0.034156 | 0.000687 | 10 | 10 | 0.4 | 10 | 250 | 0.568750 | ... | 0.029031 | 1 | 0.877270 | 0.869130 | 0.879148 | 0.870463 | 0.882353 | 0.862954 | 0.873553 | 0.006634 |
| 808 | 6.637710 | 0.042254 | 0.034489 | 0.000500 | 10 | 10 | 0.6 | 10 | 250 | 0.565625 | ... | 0.026321 | 2 | 0.871634 | 0.867877 | 0.880401 | 0.866083 | 0.885482 | 0.867334 | 0.873135 | 0.007286 |
| 724 | 6.972770 | 0.058687 | 0.064813 | 0.000897 | 10 | 10 | sqrt | 2 | 500 | 0.559375 | ... | 0.025288 | 3 | 0.947401 | 0.930495 | 0.940513 | 0.935544 | 0.944931 | 0.927409 | 0.937716 | 0.007266 |
| 743 | 6.102215 | 0.074741 | 0.035822 | 0.000687 | 10 | 10 | log2 | 5 | 250 | 0.559375 | ... | 0.022637 | 4 | 0.929242 | 0.908579 | 0.927990 | 0.913642 | 0.918648 | 0.912390 | 0.918415 | 0.007797 |
| 732 | 5.587380 | 0.017596 | 0.016161 | 0.000372 | 10 | 10 | sqrt | 10 | 100 | 0.546875 | ... | 0.027574 | 5 | 0.876644 | 0.876018 | 0.871634 | 0.867334 | 0.884230 | 0.856696 | 0.872093 | 0.008596 |
| 554 | 7.116558 | 0.025470 | 0.068145 | 0.002192 | 10 | None | sqrt | 10 | 500 | 0.562500 | ... | 0.018288 | 6 | 0.964934 | 0.973075 | 0.973075 | 0.966208 | 0.974343 | 0.974343 | 0.970996 | 0.003888 |
| 723 | 6.103382 | 0.068446 | 0.034822 | 0.000687 | 10 | 10 | sqrt | 2 | 250 | 0.571875 | ... | 0.031380 | 6 | 0.948028 | 0.932999 | 0.942392 | 0.939925 | 0.944305 | 0.931790 | 0.939906 | 0.005846 |
| 731 | 5.417767 | 0.070397 | 0.010497 | 0.000764 | 10 | 10 | sqrt | 10 | 50 | 0.584375 | ... | 0.019176 | 8 | 0.865999 | 0.854728 | 0.867877 | 0.855444 | 0.876095 | 0.854819 | 0.862494 | 0.008116 |
| 612 | 5.674352 | 0.039028 | 0.016995 | 0.000577 | 10 | None | 0.4 | 10 | 100 | 0.550000 | ... | 0.027673 | 9 | 0.968065 | 0.962430 | 0.966813 | 0.963079 | 0.968711 | 0.966208 | 0.965884 | 0.002364 |
| 597 | 5.523401 | 0.029758 | 0.017161 | 0.000897 | 10 | None | 0.2 | 10 | 100 | 0.540625 | ... | 0.019971 | 10 | 0.967439 | 0.959925 | 0.961177 | 0.955569 | 0.964330 | 0.958073 | 0.961086 | 0.003911 |
10 rows × 26 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 481 | 5.237825 | 0.057107 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 482 | 5.240824 | 0.041455 | 0.0 | 0.0 | 6 | 10 | None | 2 | 100 | NaN | ... | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 483 | 5.305137 | 0.061377 | 0.0 | 0.0 | 6 | 10 | None | 2 | 250 | NaN | ... | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 484 | 5.373782 | 0.042392 | 0.0 | 0.0 | 6 | 10 | None | 2 | 500 | NaN | ... | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 485 | 5.220998 | 0.057545 | 0.0 | 0.0 | 6 | 10 | None | 5 | 10 | NaN | ... | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 486 | 5.243157 | 0.037571 | 0.0 | 0.0 | 6 | 10 | None | 5 | 50 | NaN | ... | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 487 | 5.245656 | 0.056370 | 0.0 | 0.0 | 6 | 10 | None | 5 | 100 | NaN | ... | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 488 | 5.304970 | 0.031309 | 0.0 | 0.0 | 6 | 10 | None | 5 | 250 | NaN | ... | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 398 | 5.309136 | 0.035426 | 0.0 | 0.0 | 6 | 3 | None | 5 | 250 | NaN | ... | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 5.374615 | 0.049983 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 26 columns
In total there are 810 different configurations tested.
The best mean test score is 0.577
There are 1 configurations with this maximum score
There are 102 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 102.000000 | 102.000000 | 102.000000 | 1.020000e+02 | 102.0 | 59.0 | 102 | 102.0 | 102.0 | 102.000000 | ... | 102.000000 | 102.000000 | 102.000000 | 102.000000 | 102.000000 | 102.000000 | 102.000000 | 102.000000 | 102.000000 | 102.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 2.0 | 5 | 3.0 | 4.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | log2 | 10.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 102.0 | 58.0 | 23 | 35.0 | 29.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 6.200793 | 0.049928 | 0.035069 | 1.463454e-03 | NaN | NaN | NaN | NaN | NaN | 0.546691 | ... | 0.025157 | 51.411765 | 0.941060 | 0.935099 | 0.939537 | 0.934569 | 0.942594 | 0.933612 | 0.937745 | 0.004694 |
| std | 0.773970 | 0.013993 | 0.022639 | 1.754671e-03 | NaN | NaN | NaN | NaN | NaN | 0.013870 | ... | 0.004430 | 29.567963 | 0.053956 | 0.055417 | 0.054273 | 0.056511 | 0.052175 | 0.058204 | 0.054939 | 0.003169 |
| min | 5.375115 | 0.017596 | 0.010163 | 3.276750e-07 | NaN | NaN | NaN | NaN | NaN | 0.506250 | ... | 0.013656 | 1.000000 | 0.607389 | 0.636193 | 0.613024 | 0.622653 | 0.626408 | 0.608886 | 0.619092 | 0.000000 |
| 25% | 5.561846 | 0.040449 | 0.016203 | 4.995765e-04 | NaN | NaN | NaN | NaN | NaN | 0.537500 | ... | 0.022459 | 26.250000 | 0.915780 | 0.907952 | 0.914058 | 0.903786 | 0.919274 | 0.908010 | 0.911294 | 0.002060 |
| 50% | 6.030488 | 0.051755 | 0.034656 | 7.636423e-04 | NaN | NaN | NaN | NaN | NaN | 0.546875 | ... | 0.025190 | 51.500000 | 0.944584 | 0.929869 | 0.937383 | 0.934293 | 0.943054 | 0.930538 | 0.936203 | 0.005011 |
| 75% | 6.689236 | 0.058965 | 0.064813 | 1.710333e-03 | NaN | NaN | NaN | NaN | NaN | 0.556250 | ... | 0.027664 | 76.750000 | 0.987790 | 0.987007 | 0.988885 | 0.989831 | 0.990613 | 0.987484 | 0.988524 | 0.007244 |
| max | 8.411477 | 0.084549 | 0.073977 | 1.072664e-02 | NaN | NaN | NaN | NaN | NaN | 0.584375 | ... | 0.036174 | 102.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.012191 |
11 rows × 26 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
best_found_csp_components = [10, 10, 10]
best_found_rf_depth = [3, None, 10]
best_found_rf_max_features = [0.4, 0.2, 0.4]
best_found_rf_min_sample = [10, 2, 10]
best_found_rf_n_estimators = [500, 50, 250]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    ################# TRAINING DATA #################
    with io.capture_output():
        # Get all training data (all but last session of participant)
        mne_raws = CLA_dataset.get_all_but_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for all those MNE raws (all training sessions)
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
    # Get the labels
    y_train = mne_epochs.events[:, -1]
    # Use a fixed filter
    mne_epochs.filter(l_freq= filter_lower_bound,
                      h_freq= filter_upper_bound,
                      picks= "all",
                      phase= "minimum",
                      fir_window= "blackman",
                      fir_design= "firwin",
                      pad= 'median',
                      n_jobs= -1,
                      verbose= False)
    # Get a half second window
    X_train = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for training data
    del mne_raws
    del mne_raw
    del mne_epochs
    ################# TESTING DATA #################
    with io.capture_output():
        # Get test data
        mne_raw = CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Get epochs for test MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
    # Get the labels
    y_test = mne_epochs.events[:, -1]
    # Use a fixed filter
    mne_epochs.filter(l_freq= filter_lower_bound,
                      h_freq= filter_upper_bound,
                      picks= "all",
                      phase= "minimum",
                      fir_window= "blackman",
                      fir_design= "firwin",
                      pad= 'median',
                      n_jobs= -1,
                      verbose= False)
    # Get a half second window
    X_test = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for testing data
    del mne_raw
    del mne_epochs
    ################# FIT AND PREDICT #################
    # Make the classifier
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    rf = RandomForestClassifier(bootstrap= True,
                                criterion= "gini",
                                max_depth= best_found_rf_depth[i],
                                max_features= best_found_rf_max_features[i],
                                min_samples_split= best_found_rf_min_sample[i],
                                n_estimators= best_found_rf_n_estimators[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('RF', rf)])
    # Fit the pipeline
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_rf_depth
del best_found_rf_max_features
del best_found_rf_min_sample
del best_found_rf_n_estimators
del i
del X_test
del y_test
del X_train
del y_train
del csp
del rf
del pipeline
del y_pred
del accuracy
del start_offset
del end_offset
del baseline
del filter_lower_bound
del filter_upper_bound
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.4312
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.3233
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.4901
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
As discussed in the master's thesis, a classification system can be trained and tested using multiple strategies. The hardest strategy discussed here is training a classifier on data from one or more subjects and then using it to classify data from a completely new user. This section trains the same classifiers for the same participants as before, but uses one participant for testing and the other two for training.
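The subject split used in this section can be sketched as follows. This is a minimal illustration, not the notebook's actual implementation: the helper name `leave_one_subject_out` is hypothetical, and only the subject IDs are taken from the notebook.

```python
def leave_one_subject_out(subject_ids):
    """Yield (train_subjects, test_subject) pairs, holding out one subject at a time."""
    for test_subject in subject_ids:
        # Every subject except the held-out one is used for training
        train_subjects = [s for s in subject_ids if s != test_subject]
        yield train_subjects, test_subject


# Subjects with three recordings, as used throughout this notebook
splits = list(leave_one_subject_out(["B", "C", "E"]))
for train_subjects, test_subject in splits:
    print(f"Train on {train_subjects}, test on {test_subject}")
```

Each subject is held out exactly once, so the classifier never sees data from the participant it is tested on.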
This experiment works as follows:
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        ###################### PREPARE DATA ######################
        with io.capture_output():
            # Determine the train subjects (copy the list so the original stays intact)
            train_subjects = list(subject_ids_to_test)
            train_subjects.remove(subject_id)
            mne_raws = []
            # Get all training data
            for train_subject in train_subjects:
                mne_raws.extend(CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= train_subject))
            # Combine training data into singular mne raw
            mne_raw = mne.concatenate_raws(mne_raws)
            # Delete all raws since concat changes them
            del mne_raws
            # Get epochs for that MNE raw
            mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                                 start_offset= start_offset,
                                                                 end_offset= end_offset,
                                                                 baseline= baseline)
            # Only keep epochs from the MI tasks
            mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
            # Load epochs into memory
            mne_epochs.load_data()
        # Show training data
        print(f"Using data from participants {train_subjects} to train for testing on participant {subject_id}")
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        lda = LinearDiscriminantAnalysis(shrinkage= None,
                                         priors=[1/3, 1/3, 1/3])
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('LDA', lda)])
        # Configure cross validation to use, more splits than before since we have more data
        cv = StratifiedKFold(n_splits= 10,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        param_grid = [{"CSP__n_components": [2, 3, 4, 6, 10],
                       "LDA__solver": ["svd"],
                       "LDA__tol": [0.0001, 0.00001, 0.001, 0.0004, 0.00007]
                      },
                      {"CSP__n_components": [2, 3, 4, 6, 10],
                       "LDA__solver": ["lsqr" , "eigen"]
                      }]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= mne_epochs_data, y= labels)
        # Store the results of the grid search
        with open(f"saved_variables/2/newsubject/subject{subject_id}/gridsearch_csplda.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del lda
        del pipeline
        del labels
        del cv
        del file
        del grid_search
        del param_grid
        del train_subject
        del train_subjects
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
The CV results are based on the training set alone, which here consists of all recordings from the two training participants. The test result is for a new, unseen participant, so the scores are expected to differ.
| Subject | CSP + LDA: cross validation accuracy | CSP + LDA: test split accuracy | CSP config | LDA config |
|---|---|---|---|---|
| B (Train on C&E) | 0.5662 +- 0.0129 | 0.3961 | 10 CSP components | SVD LDA with 0.0001 tol |
| C (Train on B&E) | 0.4781 +- 0.0185 | 0.4731 | 10 CSP components | lsqr LDA |
| E (Train on B&C) | 0.5567 +- 0.0287 | 0.4098 | 10 CSP components | SVD LDA with 0.0001 tol |
It becomes clear that CSP + LDA struggles with this task: the test accuracies of roughly 0.40 to 0.47 are only modestly above the chance level of about 0.33 for three classes.
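To make the comparison with chance concrete, the short sketch below computes the margin of each test accuracy over chance, along with the drop from cross-validation to test accuracy. The accuracies are copied from the table above. The chance level of 1/3 assumes three balanced classes, which matches the uniform LDA priors used here but is not verified against the actual epoch counts.

```python
# Assumption: three balanced classes, so chance accuracy is 1/3
chance_level = 1 / 3

# Accuracies copied from the results table above
results = {
    "B": {"cv": 0.5662, "test": 0.3961},
    "C": {"cv": 0.4781, "test": 0.4731},
    "E": {"cv": 0.5567, "test": 0.4098},
}

summary = {}
for subject, scores in results.items():
    summary[subject] = {
        # How far the held-out participant's accuracy lies above chance
        "above_chance": round(scores["test"] - chance_level, 4),
        # How much accuracy is lost when moving from CV to the new participant
        "cv_test_gap": round(scores["cv"] - scores["test"], 4),
    }
    print(f"Subject {subject}: "
          f"{summary[subject]['above_chance']:+.4f} above chance, "
          f"CV-to-test gap of {summary[subject]['cv_test_gap']:.4f}")
```

Under these assumptions, subjects B and E lose about 0.15 to 0.17 accuracy when moving to the unseen participant, while subject C shows almost no gap but also has the lowest cross-validation score.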
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/newsubject/subject{subject_id}/gridsearch_csplda.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.5662 +- 0.0129 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 24 | 16.677791 | 0.082748 | 0.008598 | 0.000916 | 10 | svd | 0.00007 | 0.585069 | 0.572174 | 0.587826 | ... | 0.569359 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574575 | 0.573802 | 0.577252 | 0.005774 |
| 23 | 16.679991 | 0.071716 | 0.008298 | 0.001004 | 10 | svd | 0.0004 | 0.585069 | 0.572174 | 0.587826 | ... | 0.569359 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574575 | 0.573802 | 0.577252 | 0.005774 |
| 22 | 16.687789 | 0.076678 | 0.010497 | 0.004030 | 10 | svd | 0.001 | 0.585069 | 0.572174 | 0.587826 | ... | 0.569359 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574575 | 0.573802 | 0.577252 | 0.005774 |
| 21 | 16.646401 | 0.066482 | 0.009497 | 0.001204 | 10 | svd | 0.00001 | 0.585069 | 0.572174 | 0.587826 | ... | 0.569359 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574575 | 0.573802 | 0.577252 | 0.005774 |
| 20 | 16.675692 | 0.052500 | 0.009297 | 0.000640 | 10 | svd | 0.0001 | 0.585069 | 0.572174 | 0.587826 | ... | 0.569359 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574575 | 0.573802 | 0.577252 | 0.005774 |
| 34 | 16.566027 | 0.203176 | 0.007698 | 0.000641 | 10 | eigen | NaN | 0.585069 | 0.572174 | 0.587826 | ... | 0.569552 | 0.574961 | 0.579791 | 0.573223 | 0.583076 | 0.590611 | 0.574768 | 0.574189 | 0.577214 | 0.005643 |
| 33 | 16.743470 | 0.095126 | 0.007798 | 0.001248 | 10 | lsqr | NaN | 0.585069 | 0.572174 | 0.587826 | ... | 0.569552 | 0.574961 | 0.579791 | 0.573223 | 0.583076 | 0.590611 | 0.574768 | 0.574189 | 0.577214 | 0.005643 |
| 32 | 16.590419 | 0.063734 | 0.005798 | 0.000400 | 6 | eigen | NaN | 0.574653 | 0.554783 | 0.563478 | ... | 0.569745 | 0.556607 | 0.566847 | 0.553903 | 0.559892 | 0.556028 | 0.562403 | 0.558733 | 0.560675 | 0.004632 |
| 31 | 16.604115 | 0.061809 | 0.005499 | 0.000500 | 6 | lsqr | NaN | 0.574653 | 0.554783 | 0.563478 | ... | 0.569745 | 0.556607 | 0.566847 | 0.553903 | 0.559892 | 0.556028 | 0.562403 | 0.558733 | 0.560675 | 0.004632 |
| 19 | 16.583921 | 0.076682 | 0.005898 | 0.000830 | 6 | svd | 0.00007 | 0.574653 | 0.554783 | 0.565217 | ... | 0.570131 | 0.556801 | 0.567233 | 0.554289 | 0.559312 | 0.556221 | 0.562403 | 0.559119 | 0.560791 | 0.004682 |
10 rows × 32 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 16.551432 | 0.045822 | 0.004499 | 0.000500 | 3 | svd | 0.0001 | 0.427083 | 0.436522 | 0.387826 | ... | 0.408617 | 0.410549 | 0.407844 | 0.410355 | 0.423879 | 0.407651 | 0.407457 | 0.410162 | 0.412778 | 0.008450 |
| 27 | 16.527439 | 0.050890 | 0.004199 | 0.000400 | 3 | lsqr | NaN | 0.427083 | 0.436522 | 0.387826 | ... | 0.409003 | 0.410355 | 0.407844 | 0.410162 | 0.423879 | 0.407651 | 0.407264 | 0.410162 | 0.412759 | 0.008394 |
| 28 | 16.506946 | 0.083724 | 0.004099 | 0.000300 | 3 | eigen | NaN | 0.427083 | 0.436522 | 0.387826 | ... | 0.409003 | 0.410355 | 0.407844 | 0.410162 | 0.423879 | 0.407651 | 0.407264 | 0.410162 | 0.412759 | 0.008394 |
| 1 | 16.638651 | 0.077663 | 0.004499 | 0.000500 | 2 | svd | 0.00001 | 0.388889 | 0.408696 | 0.351304 | ... | 0.376932 | 0.380216 | 0.380603 | 0.378478 | 0.382535 | 0.375386 | 0.371522 | 0.382148 | 0.381209 | 0.008991 |
| 2 | 16.578723 | 0.100582 | 0.004399 | 0.000490 | 2 | svd | 0.001 | 0.388889 | 0.408696 | 0.351304 | ... | 0.376932 | 0.380216 | 0.380603 | 0.378478 | 0.382535 | 0.375386 | 0.371522 | 0.382148 | 0.381209 | 0.008991 |
| 3 | 16.556730 | 0.085611 | 0.004199 | 0.000400 | 2 | svd | 0.0004 | 0.388889 | 0.408696 | 0.351304 | ... | 0.376932 | 0.380216 | 0.380603 | 0.378478 | 0.382535 | 0.375386 | 0.371522 | 0.382148 | 0.381209 | 0.008991 |
| 4 | 16.607914 | 0.090965 | 0.004099 | 0.000300 | 2 | svd | 0.00007 | 0.388889 | 0.408696 | 0.351304 | ... | 0.376932 | 0.380216 | 0.380603 | 0.378478 | 0.382535 | 0.375386 | 0.371522 | 0.382148 | 0.381209 | 0.008991 |
| 0 | 16.979295 | 0.465478 | 0.004399 | 0.000490 | 2 | svd | 0.0001 | 0.388889 | 0.408696 | 0.351304 | ... | 0.376932 | 0.380216 | 0.380603 | 0.378478 | 0.382535 | 0.375386 | 0.371522 | 0.382148 | 0.381209 | 0.008991 |
| 25 | 16.515143 | 0.071153 | 0.003299 | 0.000458 | 2 | lsqr | NaN | 0.388889 | 0.406957 | 0.351304 | ... | 0.376739 | 0.380216 | 0.380410 | 0.378284 | 0.382535 | 0.375580 | 0.371522 | 0.382342 | 0.381151 | 0.008952 |
| 26 | 16.472157 | 0.054163 | 0.003199 | 0.000400 | 2 | eigen | NaN | 0.388889 | 0.406957 | 0.351304 | ... | 0.376739 | 0.380216 | 0.380410 | 0.378284 | 0.382535 | 0.375580 | 0.371522 | 0.382342 | 0.381151 | 0.008952 |
10 rows × 32 columns
In total there are 35 different configurations tested.
The best mean test score is 0.5662
There are 5 configurations with this maximum score
There are 14 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.0 | 14 | 10.00000 | 14.000000 | 14.000000 | 14.000000 | ... | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 3 | 5.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 10 | 2.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 16.625108 | 0.077831 | 0.007484 | 0.001096 | NaN | NaN | NaN | 0.579861 | 0.563478 | 0.576273 | ... | 0.569717 | 0.565853 | 0.573664 | 0.563839 | 0.571484 | 0.573388 | 0.568517 | 0.566461 | 0.569000 | 0.005202 |
| std | 0.060383 | 0.039091 | 0.001577 | 0.000890 | NaN | NaN | NaN | 0.005405 | 0.009024 | 0.012003 | ... | 0.000346 | 0.009452 | 0.006792 | 0.010026 | 0.012462 | 0.017873 | 0.006344 | 0.007735 | 0.008553 | 0.000557 |
| min | 16.541734 | 0.043497 | 0.005499 | 0.000400 | NaN | NaN | NaN | 0.574653 | 0.554783 | 0.563478 | ... | 0.569359 | 0.556607 | 0.566847 | 0.553903 | 0.559312 | 0.556028 | 0.562403 | 0.558733 | 0.560675 | 0.004632 |
| 25% | 16.580997 | 0.057642 | 0.006073 | 0.000640 | NaN | NaN | NaN | 0.574653 | 0.554783 | 0.565217 | ... | 0.569359 | 0.556801 | 0.567233 | 0.554289 | 0.559312 | 0.556221 | 0.562403 | 0.559119 | 0.560791 | 0.004682 |
| 50% | 16.611212 | 0.069099 | 0.007298 | 0.000908 | NaN | NaN | NaN | 0.579861 | 0.563478 | 0.576522 | ... | 0.569648 | 0.565881 | 0.573512 | 0.563756 | 0.571484 | 0.573416 | 0.568489 | 0.566461 | 0.569002 | 0.005162 |
| 75% | 16.677266 | 0.081231 | 0.008523 | 0.001178 | NaN | NaN | NaN | 0.585069 | 0.572174 | 0.587826 | ... | 0.570131 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574575 | 0.573802 | 0.577252 | 0.005774 |
| max | 16.743470 | 0.203176 | 0.010497 | 0.004030 | NaN | NaN | NaN | 0.585069 | 0.572174 | 0.587826 | ... | 0.570131 | 0.574961 | 0.580371 | 0.573609 | 0.583655 | 0.590611 | 0.574768 | 0.574189 | 0.577252 | 0.005774 |
11 rows × 32 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.4781 +- 0.0185 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'lsqr'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34 | 16.637804 | 0.222320 | 0.010898 | 0.004865 | 10 | eigen | NaN | 0.471304 | 0.469565 | 0.471304 | ... | 0.493913 | 0.488116 | 0.490048 | 0.490242 | 0.495459 | 0.500290 | 0.491208 | 0.483478 | 0.492657 | 0.005253 |
| 33 | 16.747369 | 0.125512 | 0.008198 | 0.000979 | 10 | lsqr | NaN | 0.471304 | 0.469565 | 0.471304 | ... | 0.493913 | 0.488116 | 0.490048 | 0.490242 | 0.495459 | 0.500290 | 0.491208 | 0.483478 | 0.492657 | 0.005253 |
| 24 | 16.769962 | 0.142414 | 0.008398 | 0.000663 | 10 | svd | 0.00007 | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500483 | 0.491014 | 0.483478 | 0.492599 | 0.005310 |
| 23 | 16.820446 | 0.161290 | 0.008798 | 0.001249 | 10 | svd | 0.0004 | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500483 | 0.491014 | 0.483478 | 0.492599 | 0.005310 |
| 22 | 16.834341 | 0.155076 | 0.009197 | 0.000871 | 10 | svd | 0.001 | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500483 | 0.491014 | 0.483478 | 0.492599 | 0.005310 |
| 21 | 16.754367 | 0.151045 | 0.008598 | 0.000663 | 10 | svd | 0.00001 | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500483 | 0.491014 | 0.483478 | 0.492599 | 0.005310 |
| 20 | 16.770862 | 0.154908 | 0.009497 | 0.001025 | 10 | svd | 0.0001 | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500483 | 0.491014 | 0.483478 | 0.492599 | 0.005310 |
| 32 | 16.703383 | 0.082764 | 0.005499 | 0.000500 | 6 | eigen | NaN | 0.452174 | 0.427826 | 0.412174 | ... | 0.457778 | 0.456039 | 0.457198 | 0.445217 | 0.447150 | 0.446377 | 0.451787 | 0.444251 | 0.448290 | 0.007146 |
| 31 | 16.759965 | 0.146001 | 0.005899 | 0.000830 | 6 | lsqr | NaN | 0.452174 | 0.427826 | 0.412174 | ... | 0.457778 | 0.456039 | 0.457198 | 0.445217 | 0.447150 | 0.446377 | 0.451787 | 0.444251 | 0.448290 | 0.007146 |
| 19 | 16.657498 | 0.118762 | 0.005799 | 0.000400 | 6 | svd | 0.00007 | 0.450435 | 0.427826 | 0.413913 | ... | 0.457585 | 0.455845 | 0.457198 | 0.444831 | 0.447150 | 0.446763 | 0.451594 | 0.444058 | 0.448232 | 0.007070 |
10 rows × 32 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 27 | 16.665595 | 0.175537 | 0.004299 | 0.000458 | 3 | lsqr | NaN | 0.415652 | 0.358261 | 0.354783 | ... | 0.393816 | 0.378744 | 0.417971 | 0.386087 | 0.401932 | 0.387246 | 0.387053 | 0.384348 | 0.393681 | 0.012063 |
| 28 | 16.651799 | 0.129599 | 0.004199 | 0.000400 | 3 | eigen | NaN | 0.415652 | 0.358261 | 0.354783 | ... | 0.393816 | 0.378744 | 0.417971 | 0.386087 | 0.401932 | 0.387246 | 0.387053 | 0.384348 | 0.393681 | 0.012063 |
| 5 | 16.605414 | 0.171615 | 0.004299 | 0.000458 | 3 | svd | 0.0001 | 0.415652 | 0.358261 | 0.354783 | ... | 0.393623 | 0.378744 | 0.418164 | 0.386280 | 0.402319 | 0.387246 | 0.387053 | 0.384348 | 0.393720 | 0.012089 |
| 25 | 16.640303 | 0.148570 | 0.003799 | 0.000600 | 2 | lsqr | NaN | 0.379130 | 0.330435 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413720 | 0.371594 | 0.374493 | 0.371981 | 0.372560 | 0.370821 | 0.377913 | 0.012250 |
| 26 | 16.624708 | 0.135586 | 0.003200 | 0.000400 | 2 | eigen | NaN | 0.379130 | 0.330435 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413720 | 0.371594 | 0.374493 | 0.371981 | 0.372560 | 0.370821 | 0.377913 | 0.012250 |
| 1 | 16.639603 | 0.168400 | 0.003900 | 0.000538 | 2 | svd | 0.00001 | 0.379130 | 0.328696 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413913 | 0.371594 | 0.374493 | 0.371981 | 0.372754 | 0.370628 | 0.377932 | 0.012309 |
| 3 | 16.574024 | 0.168459 | 0.003699 | 0.000640 | 2 | svd | 0.0004 | 0.379130 | 0.328696 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413913 | 0.371594 | 0.374493 | 0.371981 | 0.372754 | 0.370628 | 0.377932 | 0.012309 |
| 4 | 16.561728 | 0.165540 | 0.004099 | 0.000300 | 2 | svd | 0.00007 | 0.379130 | 0.328696 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413913 | 0.371594 | 0.374493 | 0.371981 | 0.372754 | 0.370628 | 0.377932 | 0.012309 |
| 2 | 16.584820 | 0.147368 | 0.003999 | 0.000447 | 2 | svd | 0.001 | 0.379130 | 0.328696 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413913 | 0.371594 | 0.374493 | 0.371981 | 0.372754 | 0.370628 | 0.377932 | 0.012309 |
| 0 | 16.605814 | 0.187090 | 0.004599 | 0.000800 | 2 | svd | 0.0001 | 0.379130 | 0.328696 | 0.347826 | ... | 0.378937 | 0.374300 | 0.413913 | 0.371594 | 0.374493 | 0.371981 | 0.372754 | 0.370628 | 0.377932 | 0.012309 |
10 rows × 32 columns
In total there are 35 different configurations tested.
The best mean test score is 0.4781
There are 2 configurations with this maximum score
There are 7 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.0 | 7 | 5.00000 | 7.000000 | 7.000000 | 7.000000 | ... | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 7.000000e+00 | 7.000000 | 7.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 3 | 5.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 5 | 1.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 16.762164 | 0.158938 | 0.009083 | 0.001474 | NaN | NaN | NaN | 0.472547 | 0.469565 | 0.471304 | ... | 0.494051 | 0.488116 | 0.489910 | 0.489965 | 0.495321 | 0.500428 | 0.491070 | 4.834783e-01 | 0.492616 | 0.005294 |
| std | 0.063885 | 0.030295 | 0.000917 | 0.001510 | NaN | NaN | NaN | 0.000849 | 0.000000 | 0.000000 | ... | 0.000094 | 0.000000 | 0.000094 | 0.000189 | 0.000094 | 0.000094 | 0.000094 | 5.995890e-17 | 0.000028 | 0.000028 |
| min | 16.637804 | 0.125512 | 0.008198 | 0.000663 | NaN | NaN | NaN | 0.471304 | 0.469565 | 0.471304 | ... | 0.493913 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500290 | 0.491014 | 4.834783e-01 | 0.492599 | 0.005253 |
| 25% | 16.750868 | 0.146729 | 0.008498 | 0.000767 | NaN | NaN | NaN | 0.472174 | 0.469565 | 0.471304 | ... | 0.494010 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500386 | 0.491014 | 4.834783e-01 | 0.492599 | 0.005281 |
| 50% | 16.769962 | 0.154908 | 0.008798 | 0.000979 | NaN | NaN | NaN | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489855 | 0.489855 | 0.495266 | 0.500483 | 0.491014 | 4.834783e-01 | 0.492599 | 0.005310 |
| 75% | 16.795654 | 0.158183 | 0.009347 | 0.001137 | NaN | NaN | NaN | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.489952 | 0.490048 | 0.495362 | 0.500483 | 0.491111 | 4.834783e-01 | 0.492628 | 0.005310 |
| max | 16.834341 | 0.222320 | 0.010898 | 0.004865 | NaN | NaN | NaN | 0.473043 | 0.469565 | 0.471304 | ... | 0.494106 | 0.488116 | 0.490048 | 0.490242 | 0.495459 | 0.500483 | 0.491208 | 4.834783e-01 | 0.492657 | 0.005310 |
11 rows × 32 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.5567 +- 0.0287 with the following parameters
{'CSP__n_components': 10, 'LDA__solver': 'svd', 'LDA__tol': 0.0001}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34 | 16.456062 | 0.292352 | 0.008201 | 0.000749 | 10 | eigen | NaN | 0.548611 | 0.536458 | 0.630208 | ... | 0.574213 | 0.561089 | 0.593901 | 0.567072 | 0.578074 | 0.550174 | 0.572366 | 0.567156 | 0.570591 | 0.010733 |
| 24 | 16.736373 | 0.082024 | 0.008798 | 0.001400 | 10 | svd | 0.00007 | 0.548611 | 0.536458 | 0.628472 | ... | 0.574599 | 0.560896 | 0.593322 | 0.567265 | 0.577881 | 0.550560 | 0.573138 | 0.566770 | 0.570707 | 0.010584 |
| 23 | 16.748569 | 0.043629 | 0.008797 | 0.001166 | 10 | svd | 0.0004 | 0.548611 | 0.536458 | 0.628472 | ... | 0.574599 | 0.560896 | 0.593322 | 0.567265 | 0.577881 | 0.550560 | 0.573138 | 0.566770 | 0.570707 | 0.010584 |
| 22 | 16.798453 | 0.076071 | 0.008897 | 0.000943 | 10 | svd | 0.001 | 0.548611 | 0.536458 | 0.628472 | ... | 0.574599 | 0.560896 | 0.593322 | 0.567265 | 0.577881 | 0.550560 | 0.573138 | 0.566770 | 0.570707 | 0.010584 |
| 21 | 16.773561 | 0.056015 | 0.009198 | 0.001248 | 10 | svd | 0.00001 | 0.548611 | 0.536458 | 0.628472 | ... | 0.574599 | 0.560896 | 0.593322 | 0.567265 | 0.577881 | 0.550560 | 0.573138 | 0.566770 | 0.570707 | 0.010584 |
| 20 | 16.751068 | 0.072718 | 0.009197 | 0.000600 | 10 | svd | 0.0001 | 0.548611 | 0.536458 | 0.628472 | ... | 0.574599 | 0.560896 | 0.593322 | 0.567265 | 0.577881 | 0.550560 | 0.573138 | 0.566770 | 0.570707 | 0.010584 |
| 33 | 16.653899 | 0.053267 | 0.007998 | 0.001000 | 10 | lsqr | NaN | 0.548611 | 0.536458 | 0.630208 | ... | 0.574213 | 0.561089 | 0.593901 | 0.567072 | 0.578074 | 0.550174 | 0.572366 | 0.567156 | 0.570591 | 0.010733 |
| 32 | 16.607613 | 0.055564 | 0.005998 | 0.001414 | 6 | eigen | NaN | 0.527778 | 0.527778 | 0.616319 | ... | 0.567072 | 0.556842 | 0.577495 | 0.566493 | 0.557421 | 0.539367 | 0.569085 | 0.556156 | 0.562408 | 0.009905 |
| 31 | 16.621209 | 0.053444 | 0.005699 | 0.000900 | 6 | lsqr | NaN | 0.527778 | 0.527778 | 0.616319 | ... | 0.567072 | 0.556842 | 0.577495 | 0.566493 | 0.557421 | 0.539367 | 0.569085 | 0.556156 | 0.562408 | 0.009905 |
| 19 | 16.687588 | 0.088495 | 0.005699 | 0.000458 | 6 | svd | 0.00007 | 0.526042 | 0.526042 | 0.614583 | ... | 0.567072 | 0.556842 | 0.577688 | 0.565914 | 0.557614 | 0.539560 | 0.569471 | 0.556542 | 0.562465 | 0.009851 |
10 rows × 32 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 16.534537 | 0.047167 | 0.005599 | 4.150633e-03 | 3 | svd | 0.0001 | 0.371528 | 0.373264 | 0.420139 | ... | 0.394132 | 0.401081 | 0.471531 | 0.394518 | 0.397414 | 0.429371 | 0.396758 | 0.401196 | 0.408083 | 0.023325 |
| 27 | 16.577423 | 0.064266 | 0.003999 | 2.870940e-07 | 3 | lsqr | NaN | 0.371528 | 0.373264 | 0.420139 | ... | 0.393553 | 0.400888 | 0.471531 | 0.394711 | 0.397414 | 0.429564 | 0.396758 | 0.401775 | 0.408102 | 0.023361 |
| 28 | 16.528239 | 0.056888 | 0.004499 | 6.705054e-04 | 3 | eigen | NaN | 0.371528 | 0.373264 | 0.420139 | ... | 0.393553 | 0.400888 | 0.471531 | 0.394711 | 0.397414 | 0.429564 | 0.396758 | 0.401775 | 0.408102 | 0.023361 |
| 1 | 16.536736 | 0.046586 | 0.004299 | 6.400465e-04 | 2 | svd | 0.00001 | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346651 | 0.399151 | 0.396256 | 0.394635 | 0.397144 | 0.403512 | 0.393279 | 0.015743 |
| 2 | 16.527839 | 0.069564 | 0.004499 | 4.997493e-04 | 2 | svd | 0.001 | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346651 | 0.399151 | 0.396256 | 0.394635 | 0.397144 | 0.403512 | 0.393279 | 0.015743 |
| 3 | 16.485752 | 0.073741 | 0.003999 | 1.668930e-07 | 2 | svd | 0.0004 | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346651 | 0.399151 | 0.396256 | 0.394635 | 0.397144 | 0.403512 | 0.393279 | 0.015743 |
| 4 | 16.526539 | 0.066298 | 0.004199 | 5.998136e-04 | 2 | svd | 0.00007 | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346651 | 0.399151 | 0.396256 | 0.394635 | 0.397144 | 0.403512 | 0.393279 | 0.015743 |
| 0 | 16.487652 | 0.064448 | 0.004299 | 6.401285e-04 | 2 | svd | 0.0001 | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346651 | 0.399151 | 0.396256 | 0.394635 | 0.397144 | 0.403512 | 0.393279 | 0.015743 |
| 25 | 16.569626 | 0.062833 | 0.003499 | 4.996778e-04 | 2 | lsqr | NaN | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346844 | 0.399151 | 0.396449 | 0.394635 | 0.397144 | 0.403512 | 0.393337 | 0.015697 |
| 26 | 16.523840 | 0.050800 | 0.003599 | 4.895805e-04 | 2 | eigen | NaN | 0.378472 | 0.369792 | 0.423611 | ... | 0.395870 | 0.401274 | 0.346844 | 0.399151 | 0.396449 | 0.394635 | 0.397144 | 0.403512 | 0.393337 | 0.015697 |
10 rows × 32 columns
In total there are 35 different configurations tested.
The best mean test score is 0.5567
There are 7 configurations with this maximum score
There are 14 configurations within 0.02 of this maximum score

The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_LDA__solver | param_LDA__tol | split0_test_score | split1_test_score | split2_test_score | ... | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.0 | 14 | 10.00000 | 14.000000 | 14.000000 | 14.000000 | ... | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 | 14.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 3 | 5.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | svd | 0.00007 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 7.0 | 10 | 2.00000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 16.669223 | 0.076323 | 0.007291 | 0.000888 | NaN | NaN | NaN | 0.537574 | 0.531498 | 0.622024 | ... | 0.570781 | 0.558897 | 0.585560 | 0.566645 | 0.567748 | 0.544977 | 0.571139 | 0.561656 | 0.566561 | 0.010247 |
| std | 0.089422 | 0.064186 | 0.001525 | 0.000328 | NaN | NaN | NaN | 0.011468 | 0.005180 | 0.007252 | ... | 0.003851 | 0.002133 | 0.008229 | 0.000621 | 0.010573 | 0.005681 | 0.001867 | 0.005424 | 0.004268 | 0.000398 |
| min | 16.456062 | 0.027953 | 0.005699 | 0.000447 | NaN | NaN | NaN | 0.526042 | 0.526042 | 0.614583 | ... | 0.567072 | 0.556842 | 0.577495 | 0.565914 | 0.557421 | 0.539367 | 0.569085 | 0.556156 | 0.562408 | 0.009851 |
| 25% | 16.622184 | 0.053311 | 0.005898 | 0.000625 | NaN | NaN | NaN | 0.526042 | 0.526042 | 0.614583 | ... | 0.567072 | 0.556842 | 0.577688 | 0.565914 | 0.557614 | 0.539560 | 0.569471 | 0.556542 | 0.562465 | 0.009851 |
| 50% | 16.658497 | 0.057935 | 0.006998 | 0.000885 | NaN | NaN | NaN | 0.538194 | 0.532118 | 0.622396 | ... | 0.570643 | 0.558869 | 0.585505 | 0.566782 | 0.567748 | 0.544867 | 0.570919 | 0.561656 | 0.566528 | 0.010245 |
| 75% | 16.745520 | 0.075233 | 0.008798 | 0.001124 | NaN | NaN | NaN | 0.548611 | 0.536458 | 0.628472 | ... | 0.574599 | 0.560896 | 0.593322 | 0.567265 | 0.577881 | 0.550560 | 0.573138 | 0.566770 | 0.570707 | 0.010584 |
| max | 16.798453 | 0.292352 | 0.009198 | 0.001414 | NaN | NaN | NaN | 0.548611 | 0.536458 | 0.630208 | ... | 0.574599 | 0.561089 | 0.593901 | 0.567265 | 0.578074 | 0.550560 | 0.573138 | 0.567156 | 0.570707 | 0.010733 |
11 rows × 32 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
best_found_csp_components = [10, 10 , 10]
best_found_solver = ["svd", "lsqr", "svd"]
best_found_tol = [0.0001, 0.0001, 0.0001]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
print("\n\n")
print("####################################################")
print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
print("####################################################")
print("\n\n")
################# TRAINING DATA #################
with io.capture_output():
with io.capture_output():
# Determine the train subjects
train_subjects = copy.deepcopy(subject_ids_to_test)
train_subjects.remove(subject_ids_to_test[i])
mne_raws = []
# Get all training data
for train_subject in train_subjects:
mne_raws.extend(CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= train_subject))
# Combine training data into singular mne raw
mne_raw = mne.concatenate_raws(mne_raws)
# Get epochs for that MNE raw
mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
start_offset= start_offset,
end_offset= end_offset,
baseline= baseline)
# Only keep epochs from the MI tasks
mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
# Load epochs into memory
mne_epochs.load_data()
# Get the labels
y_train = mne_epochs.events[:, -1]
# Use a fixed filter
mne_epochs.filter(l_freq= filter_lower_bound,
h_freq= filter_upper_bound,
picks= "all",
phase= "minimum",
fir_window= "blackman",
fir_design= "firwin",
pad= 'median',
n_jobs= -1,
verbose= False)
# Get a half second window
X_train = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
# Delete resedual vars for training data
del mne_raws
del mne_raw
del mne_epochs
################# TESTING DATA #################
with io.capture_output():
# Get test data
mne_raws = CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
# Combine test data into singular mne raw
mne_raw = mne.concatenate_raws(mne_raws)
# Get epochs for test MNE raw
mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
start_offset= start_offset,
end_offset= end_offset,
baseline= baseline)
# Only keep epochs from the MI tasks
mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
# Load epochs into memory
mne_epochs.load_data()
# Get the labels
y_test = mne_epochs.events[:, -1]
# Use a fixed filter
mne_epochs.filter(l_freq= filter_lower_bound,
h_freq= filter_upper_bound,
picks= "all",
phase= "minimum",
fir_window= "blackman",
fir_design= "firwin",
pad= 'median',
n_jobs= -1,
verbose= False)
# Get a half second window
X_test = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
# Delete resedual vars for training data
del mne_raw
del mne_epochs
del mne_raws
################# FIT AND PREDICT #################
# Make the classifier
csp = CSP(norm_trace=False,
component_order="mutual_info",
cov_est= "epoch",
n_components= best_found_csp_components[i])
lda = LinearDiscriminantAnalysis(shrinkage= None,
priors=[1/3, 1/3, 1/3],
solver= best_found_solver[i],
tol= best_found_tol[i])
# Configure the pipeline
pipeline = Pipeline([('CSP', csp), ('LDA', lda)])
# Fit the pipeline
with io.capture_output():
pipeline.fit(X_train, y_train)
# Get accuracy for single fit
y_pred = pipeline.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
# Print accuracy results and CM
print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
plt.show()
# plot CSP patterns estimated on train data for visualization
pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
plt.show()
# Remove unsused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_solver
del best_found_tol
del i
del X_test
del y_test
del X_train
del y_train
del csp
del lda
del train_subjects
del train_subject
del pipeline
del y_pred
del accuracy
del start_offset
del end_offset
del baseline
del filter_lower_bound
del filter_upper_bound
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.3961
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.4731
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.4098
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
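This experiment works as follows: for each test subject (B, C and E), the recordings of the other two subjects are concatenated into a single training set, epoched, band-pass filtered between 2 and 32 Hz and cropped to the window from 0.1 to 0.6 seconds after the visual cue. A grid search with stratified 6-fold cross-validation then tunes a CSP + SVM pipeline on that training set, and the resulting grid search object is pickled for later inspection.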
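The subject-wise split behind this experiment can be sketched in isolation. The helper name `leave_one_subject_out` is hypothetical and not part of the notebook's utilities; it only mirrors the copy-and-remove logic of the cell that follows, assuming every subject ID occurs once.

```python
from typing import Dict, List


def leave_one_subject_out(subject_ids: List[str]) -> Dict[str, List[str]]:
    """Map each held-out test subject to the subjects used for training.

    Hypothetical helper, for illustration only.
    """
    splits = {}
    for test_subject in subject_ids:
        # Train on every subject except the one held out for testing
        splits[test_subject] = [s for s in subject_ids if s != test_subject]
    return splits


splits = leave_one_subject_out(["B", "C", "E"])
for test_subject, train_subjects in splits.items():
    print(f"Test on {test_subject}, train on {train_subjects}")
```

Building the training list with a comprehension avoids mutating the original list of subjects, which is why no copy is needed in this sketch.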
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        ###################### PREPARE DATA ######################
        with io.capture_output():
            # Determine the train subjects
            train_subjects = copy.deepcopy(subject_ids_to_test)
            train_subjects.remove(subject_id)
            mne_raws = []
            # Get all training data
            for train_subject in train_subjects:
                mne_raws.extend(CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= train_subject))
            # Combine training data into singular mne raw
            mne_raw = mne.concatenate_raws(mne_raws)
            # Delete all raws since concat changes them
            del mne_raws
            # Get epochs for that MNE raw
            mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                start_offset= start_offset,
                end_offset= end_offset,
                baseline= baseline)
            # Only keep epochs from the MI tasks
            mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
            # Load epochs into memory
            mne_epochs.load_data()
        # Show training data
        print(f"Using data from participants {train_subjects} to train for testing on participant {subject_id}")
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
            h_freq= filter_upper_bound,
            picks= "all",
            phase= "minimum",
            fir_window= "blackman",
            fir_design= "firwin",
            pad= 'median',
            n_jobs= -1,
            verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
            component_order="mutual_info",
            cov_est= "epoch")
        svm = SVC()
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('SVM', svm)])
        # Configure cross validation to use, more splits than before since we have more data
        cv = StratifiedKFold(n_splits= 6,
            shuffle= True,
            random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        param_grid = [{
            "CSP__n_components": [4, 6, 10],
            "SVM__C": [0.01, 0.1, 1, 10, 100],
            "SVM__kernel": ['rbf', 'sigmoid'],
            "SVM__gamma": ['scale', 'auto', 10, 1, 0.1, 0.01, 0.001]
        }, {
            "CSP__n_components": [4, 6, 10],
            "SVM__C": [0.01, 0.1, 1, 10, 100],
            "SVM__kernel": ['linear']
        }]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
            param_grid= param_grid,
            scoring= "accuracy",
            n_jobs= -1,
            refit= False, # We will do this manually
            cv= cv,
            verbose= 10,
            return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= mne_epochs_data, y= labels)
        # Store the results of the grid search
        with open(f"saved_variables/2/newsubject/subject{subject_id}/gridsearch_cspsvm.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del svm
        del pipeline
        del labels
        del cv
        del file
        del grid_search
        del param_grid
        del train_subject
        del train_subjects
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
The CV results are based on the training set alone and thus only cover the two training subjects. The test result is for a new, unseen subject, so the scores are expected to differ.
| Subject | CSP + SVM: cross validation accuracy | CSP + SVM: test split accuracy | Config |
|---|---|---|---|
| B (Train on C&E) | 0.6039 ± 0.0104 | 0.3857 | 10 CSP components, RBF SVM with C 10 and gamma auto |
| C (Train on B&E) | 0.5169 ± 0.0163 | 0.4411 | 10 CSP components, RBF SVM with C 10 and gamma scale |
| E (Train on B&C) | 0.5736 ± 0.022 | 0.3381 | 10 CSP components, RBF SVM with C 1 and gamma scale |
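CSP + SVM clearly struggles with this task: the test accuracies of 0.34 to 0.44 are close to the chance level of roughly 0.33 for a three-class problem, and well below the cross-validation accuracies.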
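To make the comparison with chance concrete, the sketch below recomputes two quantities from the table above: the drop from cross-validation to test accuracy, and the margin of the test accuracy over chance. The chance level of 1/3 is an assumption that holds only if the three classes are roughly balanced; the variable names are illustrative and not taken from the notebook.

```python
# Accuracies copied from the table above: (cross-validation, test)
results = {
    "B": (0.6039, 0.3857),
    "C": (0.5169, 0.4411),
    "E": (0.5736, 0.3381),
}
n_classes = 3  # neutral, left, right
chance_level = 1 / n_classes  # assumption: roughly balanced classes

summary = {}
for subject, (cv_acc, test_acc) in results.items():
    summary[subject] = {
        # How much accuracy is lost when moving to the unseen subject
        "cv_minus_test": round(cv_acc - test_acc, 4),
        # How far the test accuracy sits above the assumed chance level
        "test_minus_chance": round(test_acc - chance_level, 4),
    }
    print(subject, summary[subject])
```

For subjects B and E the drop from cross-validation to test is larger than 0.2, while the margin over the assumed chance level is about 0.05 or less.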
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/newsubject/subject{subject_id}/gridsearch_cspsvm.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.6039 +- 0.0104 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 10, 'SVM__gamma': 'auto', 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 184 | 16.451597 | 0.090835 | 0.374546 | 0.014835 | 10 | 10 | auto | rbf | 0.617310 | 0.589155 | ... | 0.010392 | 1 | 0.666736 | 0.683431 | 0.686144 | 0.677864 | 0.686209 | 0.687670 | 0.681342 | 0.007264 |
| 190 | 16.485419 | 0.051434 | 0.375546 | 0.015569 | 10 | 10 | 0.1 | rbf | 0.617310 | 0.589155 | ... | 0.010392 | 1 | 0.666736 | 0.683431 | 0.686144 | 0.677864 | 0.686209 | 0.687670 | 0.681342 | 0.007264 |
| 168 | 16.353461 | 0.043179 | 0.395707 | 0.007816 | 10 | 1 | scale | rbf | 0.608968 | 0.602711 | ... | 0.008346 | 3 | 0.629382 | 0.655676 | 0.648790 | 0.637596 | 0.649906 | 0.640517 | 0.643645 | 0.008758 |
| 182 | 16.514910 | 0.029591 | 0.388042 | 0.017067 | 10 | 10 | scale | rbf | 0.613139 | 0.593326 | ... | 0.010156 | 4 | 0.675501 | 0.696786 | 0.691152 | 0.685792 | 0.693511 | 0.692468 | 0.689202 | 0.006951 |
| 170 | 16.403778 | 0.128070 | 0.412868 | 0.025998 | 10 | 1 | auto | rbf | 0.613139 | 0.596455 | ... | 0.010376 | 5 | 0.627295 | 0.646912 | 0.646285 | 0.632172 | 0.647194 | 0.635093 | 0.639159 | 0.007974 |
| 176 | 16.313141 | 0.057664 | 0.402705 | 0.013845 | 10 | 1 | 0.1 | rbf | 0.613139 | 0.596455 | ... | 0.010376 | 5 | 0.627295 | 0.646912 | 0.646285 | 0.632172 | 0.647194 | 0.635093 | 0.639159 | 0.007974 |
| 206 | 16.698351 | 0.063053 | 0.399873 | 0.019081 | 10 | 100 | 0.01 | rbf | 0.615224 | 0.594369 | ... | 0.010573 | 7 | 0.620409 | 0.636477 | 0.628339 | 0.616941 | 0.630503 | 0.620697 | 0.625561 | 0.006780 |
| 174 | 16.660364 | 0.056064 | 0.450189 | 0.010173 | 10 | 1 | 1 | rbf | 0.597497 | 0.581856 | ... | 0.007144 | 8 | 0.836185 | 0.844115 | 0.868114 | 0.849364 | 0.856040 | 0.859378 | 0.852199 | 0.010403 |
| 192 | 16.346963 | 0.070117 | 0.406704 | 0.016742 | 10 | 10 | 0.01 | rbf | 0.610010 | 0.591241 | ... | 0.012555 | 9 | 0.606845 | 0.617487 | 0.618322 | 0.600668 | 0.614646 | 0.604423 | 0.610399 | 0.006759 |
| 198 | 18.301841 | 0.081918 | 0.373048 | 0.014353 | 10 | 100 | auto | rbf | 0.606882 | 0.574557 | ... | 0.012607 | 10 | 0.732888 | 0.748539 | 0.761686 | 0.750887 | 0.760901 | 0.759858 | 0.752460 | 0.010100 |
10 rows × 25 columns
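Rows 184 and 190 in the table above report identical scores. A likely explanation, based on scikit-learn's documented behaviour for `SVC`, is that `gamma='auto'` resolves to `1 / n_features`; with 10 CSP components the SVM receives 10 features, so `'auto'` coincides with the explicit grid value `0.1`. The sketch below checks this arithmetic with a hypothetical helper, `resolve_auto_gamma`, that is not part of the notebook or of scikit-learn.

```python
def resolve_auto_gamma(n_features: int) -> float:
    """Value that gamma='auto' stands for: 1 / n_features (hypothetical helper)."""
    return 1.0 / n_features


explicit_gammas = [10, 1, 0.1, 0.01, 0.001]  # numeric gamma values from the grid

duplicates_per_setting = {}
for n_components in [4, 6, 10]:
    auto_gamma = resolve_auto_gamma(n_components)
    # Explicit grid values that coincide with the resolved 'auto' value
    duplicates_per_setting[n_components] = [
        g for g in explicit_gammas if abs(g - auto_gamma) < 1e-12
    ]
    print(n_components, round(auto_gamma, 4), duplicates_per_setting[n_components])
```

Under this assumption only the 10-component setting contains a duplicated configuration, which matches the tied ranks between `auto` and `0.1` in the result tables.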
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 13 | 16.291980 | 0.058926 | 0.137456 | 0.001607 | 4 | 0.01 | 0.001 | sigmoid | 0.336809 | 0.337852 | ... | 0.000388 | 204 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 12 | 16.322637 | 0.062793 | 0.448357 | 0.009048 | 4 | 0.01 | 0.001 | rbf | 0.336809 | 0.337852 | ... | 0.000388 | 204 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 11 | 16.253992 | 0.034661 | 0.144787 | 0.007054 | 4 | 0.01 | 0.01 | sigmoid | 0.336809 | 0.337852 | ... | 0.000388 | 204 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 153 | 16.645035 | 0.076743 | 0.174444 | 0.004855 | 10 | 0.01 | 0.001 | sigmoid | 0.336809 | 0.337852 | ... | 0.000388 | 204 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 10 | 16.272153 | 0.087136 | 0.441359 | 0.010591 | 4 | 0.01 | 0.01 | rbf | 0.336809 | 0.337852 | ... | 0.000388 | 204 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 81 | 17.670208 | 1.205181 | 0.164448 | 0.023107 | 6 | 0.01 | 0.01 | sigmoid | 0.336809 | 0.336809 | ... | 0.000348 | 221 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 26 | 16.290480 | 0.037784 | 0.441526 | 0.007801 | 4 | 0.1 | 0.001 | rbf | 0.336809 | 0.336809 | ... | 0.000348 | 221 | 0.337437 | 0.337646 | 0.337437 | 0.337993 | 0.337576 | 0.337576 | 0.337611 | 0.000187 |
| 151 | 16.676191 | 0.081275 | 0.172279 | 0.005989 | 10 | 0.01 | 0.01 | sigmoid | 0.336809 | 0.336809 | ... | 0.000348 | 221 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 97 | 16.509911 | 0.074819 | 0.148286 | 0.002867 | 6 | 0.1 | 0.001 | sigmoid | 0.336809 | 0.336809 | ... | 0.000348 | 221 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
| 167 | 16.646867 | 0.037326 | 0.170612 | 0.003542 | 10 | 0.1 | 0.001 | sigmoid | 0.336809 | 0.336809 | ... | 0.000348 | 221 | 0.337437 | 0.337229 | 0.337229 | 0.337367 | 0.337367 | 0.337367 | 0.337333 | 0.000078 |
10 rows × 25 columns
In total there are 225 different configurations tested.
The best mean test score is 0.6039
There are 2 configurations with this maximum score
There are 13 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.0 | 13.0 | 13 | 13 | 13.000000 | 13.000000 | ... | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.000000 | 13.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 4.0 | 5 | 1 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | auto | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 13.0 | 4.0 | 3 | 13 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 16.754487 | 0.067633 | 0.406127 | 0.015502 | NaN | NaN | NaN | NaN | 0.607764 | 0.590599 | ... | 0.010707 | 6.769231 | 0.668534 | 0.685742 | 0.687251 | 0.676146 | 0.686755 | 0.683160 | 0.681265 | 0.008057 |
| std | 0.668553 | 0.024613 | 0.036576 | 0.004522 | NaN | NaN | NaN | NaN | 0.009622 | 0.010450 | ... | 0.001700 | 3.961352 | 0.065957 | 0.064432 | 0.072283 | 0.071606 | 0.069273 | 0.073480 | 0.069425 | 0.001438 |
| min | 16.313141 | 0.029591 | 0.360884 | 0.007816 | NaN | NaN | NaN | NaN | 0.583942 | 0.574557 | ... | 0.007144 | 1.000000 | 0.599958 | 0.610810 | 0.611227 | 0.600668 | 0.614646 | 0.604423 | 0.608242 | 0.005939 |
| 25% | 16.403778 | 0.056064 | 0.375546 | 0.013845 | NaN | NaN | NaN | NaN | 0.606882 | 0.582899 | ... | 0.010376 | 4.000000 | 0.627295 | 0.646912 | 0.646285 | 0.632172 | 0.647194 | 0.635093 | 0.639159 | 0.006951 |
| 50% | 16.495082 | 0.063053 | 0.399873 | 0.014835 | NaN | NaN | NaN | NaN | 0.610010 | 0.591241 | ... | 0.010392 | 7.000000 | 0.666736 | 0.683431 | 0.680092 | 0.677029 | 0.678281 | 0.687670 | 0.681273 | 0.007974 |
| 75% | 16.660364 | 0.081918 | 0.412868 | 0.017067 | NaN | NaN | NaN | NaN | 0.613139 | 0.596455 | ... | 0.012555 | 10.000000 | 0.675501 | 0.696786 | 0.691152 | 0.685792 | 0.693511 | 0.692468 | 0.689202 | 0.008758 |
| max | 18.301841 | 0.128070 | 0.484012 | 0.025998 | NaN | NaN | NaN | NaN | 0.617310 | 0.611053 | ... | 0.012955 | 13.000000 | 0.836185 | 0.844115 | 0.868114 | 0.849364 | 0.856040 | 0.859378 | 0.852199 | 0.010403 |
11 rows × 25 columns
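The count of 225 configurations reported in the statistics above follows directly from the two sub-grids of the parameter grid. The sketch below reproduces that count with the standard library only; `count_configurations` is a hypothetical helper, and the number of fits assumes the 6-fold cross-validation configured in the grid search cell.

```python
from math import prod

# Parameter grid copied from the grid search cell
param_grid = [
    {
        "CSP__n_components": [4, 6, 10],
        "SVM__C": [0.01, 0.1, 1, 10, 100],
        "SVM__kernel": ["rbf", "sigmoid"],
        "SVM__gamma": ["scale", "auto", 10, 1, 0.1, 0.01, 0.001],
    },
    {
        "CSP__n_components": [4, 6, 10],
        "SVM__C": [0.01, 0.1, 1, 10, 100],
        "SVM__kernel": ["linear"],
    },
]


def count_configurations(grid):
    """Sum, over sub-grids, the product of the number of values per parameter."""
    return sum(prod(len(values) for values in sub_grid.values()) for sub_grid in grid)


n_configurations = count_configurations(param_grid)  # 3*5*2*7 + 3*5
n_fits = n_configurations * 6  # one fit per configuration and fold
print(n_configurations, n_fits)
```

The kernel-based sub-grid contributes 210 configurations and the linear sub-grid 15, which adds up to the 225 configurations tested per subject.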
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.5169 +- 0.0163 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 10, 'SVM__gamma': 'scale', 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 190 | 16.936109 | 0.076218 | 0.419033 | 0.004877 | 10 | 10 | 0.1 | rbf | 0.526590 | 0.488008 | ... | 0.018261 | 1 | 0.623043 | 0.625339 | 0.603715 | 0.611853 | 0.613105 | 0.619366 | 0.616070 | 0.007356 |
| 182 | 16.872962 | 0.076008 | 0.422532 | 0.010124 | 10 | 10 | scale | rbf | 0.524505 | 0.490094 | ... | 0.016274 | 1 | 0.629931 | 0.641411 | 0.616444 | 0.621452 | 0.626669 | 0.638147 | 0.629009 | 0.008743 |
| 184 | 16.834975 | 0.075462 | 0.417367 | 0.004030 | 10 | 10 | auto | rbf | 0.526590 | 0.488008 | ... | 0.018261 | 1 | 0.623043 | 0.625339 | 0.603715 | 0.611853 | 0.613105 | 0.619366 | 0.616070 | 0.007356 |
| 198 | 18.848334 | 0.078706 | 0.411035 | 0.008893 | 10 | 100 | auto | rbf | 0.527633 | 0.480709 | ... | 0.017867 | 4 | 0.705907 | 0.700689 | 0.690526 | 0.691569 | 0.694699 | 0.701795 | 0.697531 | 0.005641 |
| 204 | 18.860830 | 0.094275 | 0.419366 | 0.017646 | 10 | 100 | 0.1 | rbf | 0.527633 | 0.480709 | ... | 0.017867 | 4 | 0.705907 | 0.700689 | 0.690526 | 0.691569 | 0.694699 | 0.701795 | 0.697531 | 0.005641 |
| 168 | 16.596883 | 0.087531 | 0.440693 | 0.008930 | 10 | 1 | scale | rbf | 0.518248 | 0.490094 | ... | 0.014571 | 4 | 0.573575 | 0.575037 | 0.564065 | 0.565735 | 0.564065 | 0.566361 | 0.568140 | 0.004458 |
| 170 | 16.632706 | 0.083044 | 0.432529 | 0.006444 | 10 | 1 | auto | rbf | 0.514077 | 0.495308 | ... | 0.011957 | 7 | 0.570027 | 0.562722 | 0.555718 | 0.557596 | 0.560100 | 0.556135 | 0.560383 | 0.004935 |
| 176 | 16.608713 | 0.116959 | 0.430862 | 0.006026 | 10 | 1 | 0.1 | rbf | 0.514077 | 0.495308 | ... | 0.011957 | 7 | 0.570027 | 0.562722 | 0.555718 | 0.557596 | 0.560100 | 0.556135 | 0.560383 | 0.004935 |
| 196 | 18.998786 | 0.067396 | 0.415034 | 0.008150 | 10 | 100 | scale | rbf | 0.520334 | 0.486966 | ... | 0.015700 | 9 | 0.717595 | 0.728867 | 0.709725 | 0.713481 | 0.719533 | 0.737270 | 0.721078 | 0.009345 |
| 174 | 16.931444 | 0.080158 | 0.464685 | 0.006742 | 10 | 1 | 1 | rbf | 0.524505 | 0.479666 | ... | 0.017687 | 10 | 0.822793 | 0.815905 | 0.813230 | 0.805718 | 0.821995 | 0.820952 | 0.816765 | 0.006007 |
10 rows × 25 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 146 | 16.916615 | 0.059122 | 0.493509 | 0.009283 | 10 | 0.01 | 1 | rbf | 0.336809 | 0.336809 | ... | 0.000166 | 197 | 0.337090 | 0.337090 | 0.337020 | 0.337020 | 0.337020 | 0.337020 | 0.337043 | 0.000033 |
| 144 | 17.076564 | 0.078158 | 0.484512 | 0.004383 | 10 | 0.01 | 10 | rbf | 0.336809 | 0.336809 | ... | 0.000166 | 197 | 0.337090 | 0.337090 | 0.337020 | 0.337020 | 0.337020 | 0.337020 | 0.337043 | 0.000033 |
| 13 | 16.472923 | 0.085952 | 0.140122 | 0.006412 | 4 | 0.01 | 0.001 | sigmoid | 0.336809 | 0.336809 | ... | 0.000166 | 197 | 0.337090 | 0.337090 | 0.337020 | 0.337020 | 0.337020 | 0.337020 | 0.337043 | 0.000033 |
| 80 | 16.621209 | 0.106626 | 0.455022 | 0.009937 | 6 | 0.01 | 0.01 | rbf | 0.336809 | 0.336809 | ... | 0.000166 | 197 | 0.337090 | 0.337090 | 0.337020 | 0.337020 | 0.337020 | 0.337020 | 0.337043 | 0.000033 |
| 74 | 16.854301 | 0.109889 | 0.452689 | 0.007709 | 6 | 0.01 | 10 | rbf | 0.336809 | 0.336809 | ... | 0.000166 | 197 | 0.337090 | 0.337090 | 0.337020 | 0.337020 | 0.337020 | 0.337020 | 0.337043 | 0.000033 |
| 19 | 16.125033 | 0.093953 | 0.138456 | 0.013909 | 4 | 0.1 | 10 | sigmoid | 0.351408 | 0.336809 | ... | 0.010198 | 221 | 0.349614 | 0.337090 | 0.342237 | 0.317613 | 0.322830 | 0.334725 | 0.334018 | 0.010917 |
| 33 | 16.080714 | 0.118058 | 0.138122 | 0.018688 | 4 | 1 | 10 | sigmoid | 0.350365 | 0.333681 | ... | 0.008590 | 222 | 0.348570 | 0.337925 | 0.343072 | 0.316361 | 0.326795 | 0.335977 | 0.334783 | 0.010593 |
| 61 | 16.107872 | 0.076417 | 0.135123 | 0.016602 | 4 | 100 | 10 | sigmoid | 0.346194 | 0.333681 | ... | 0.006708 | 223 | 0.347944 | 0.336255 | 0.340776 | 0.317404 | 0.327421 | 0.335559 | 0.334227 | 0.009715 |
| 47 | 16.040227 | 0.089166 | 0.138122 | 0.015044 | 4 | 10 | 10 | sigmoid | 0.346194 | 0.333681 | ... | 0.006765 | 224 | 0.347944 | 0.336255 | 0.340985 | 0.317404 | 0.327003 | 0.335768 | 0.334227 | 0.009794 |
| 5 | 16.775160 | 0.077913 | 0.173944 | 0.007914 | 4 | 0.01 | 10 | sigmoid | 0.326382 | 0.314911 | ... | 0.011604 | 225 | 0.330620 | 0.321645 | 0.356010 | 0.308431 | 0.328673 | 0.314065 | 0.326574 | 0.015260 |
10 rows × 25 columns
In total there are 225 different configurations tested.
The best mean test score is 0.5169
There are 3 configurations with this maximum score
There are 12 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.0 | 12.0 | 12.0 | 12 | 12.000000 | 12.000000 | ... | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 | 12.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 4.0 | 5.0 | 1 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 100.0 | 0.1 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 12.0 | 4.0 | 3.0 | 12 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 17.335357 | 0.086649 | 0.432334 | 0.008184 | NaN | NaN | NaN | NaN | 0.519639 | 0.489225 | ... | 0.014755 | 5.916667 | 0.643777 | 0.646264 | 0.631852 | 0.634738 | 0.635364 | 0.642407 | 0.639067 | 0.006944 |
| std | 0.956418 | 0.014418 | 0.022336 | 0.003484 | NaN | NaN | NaN | NaN | 0.008683 | 0.006986 | ... | 0.003825 | 3.987670 | 0.080467 | 0.080388 | 0.080907 | 0.078588 | 0.083385 | 0.085259 | 0.081347 | 0.002800 |
| min | 16.596883 | 0.067396 | 0.411035 | 0.004030 | NaN | NaN | NaN | NaN | 0.498436 | 0.479666 | ... | 0.006470 | 1.000000 | 0.548528 | 0.550198 | 0.541736 | 0.544866 | 0.536311 | 0.541319 | 0.543826 | 0.004458 |
| 25% | 16.779658 | 0.076165 | 0.418616 | 0.006340 | NaN | NaN | NaN | NaN | 0.514077 | 0.485401 | ... | 0.011957 | 3.250000 | 0.572688 | 0.571958 | 0.561978 | 0.563700 | 0.563074 | 0.563804 | 0.566200 | 0.004935 |
| 50% | 16.902203 | 0.081601 | 0.425448 | 0.007767 | NaN | NaN | NaN | NaN | 0.522419 | 0.489051 | ... | 0.015987 | 5.500000 | 0.626487 | 0.633375 | 0.610079 | 0.616653 | 0.616548 | 0.628756 | 0.622540 | 0.005824 |
| 75% | 17.517507 | 0.096082 | 0.434570 | 0.008936 | NaN | NaN | NaN | NaN | 0.526590 | 0.492961 | ... | 0.017867 | 9.250000 | 0.705907 | 0.700689 | 0.690526 | 0.691569 | 0.694699 | 0.701795 | 0.697531 | 0.007702 |
| max | 18.998786 | 0.116959 | 0.486512 | 0.017646 | NaN | NaN | NaN | NaN | 0.527633 | 0.503650 | ... | 0.018261 | 12.000000 | 0.822793 | 0.815905 | 0.813230 | 0.805718 | 0.821995 | 0.820952 | 0.816765 | 0.014240 |
11 rows × 25 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.5736 +- 0.022 with the following parameters
{'CSP__n_components': 10, 'SVM__C': 1, 'SVM__gamma': 'scale', 'SVM__kernel': 'rbf'}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 168 | 16.606214 | 0.068624 | 0.428697 | 0.024861 | 10 | 1 | scale | rbf | 0.543750 | 0.609375 | ... | 0.021977 | 1 | 0.639775 | 0.617469 | 0.636231 | 0.626511 | 0.614006 | 0.624427 | 0.626403 | 0.009245 |
| 206 | 17.116218 | 0.036241 | 0.418867 | 0.015530 | 10 | 100 | 0.01 | rbf | 0.533333 | 0.610417 | ... | 0.026016 | 2 | 0.611215 | 0.594538 | 0.611215 | 0.603585 | 0.590871 | 0.601084 | 0.602085 | 0.007667 |
| 176 | 16.506412 | 0.052660 | 0.417033 | 0.009456 | 10 | 1 | 0.1 | rbf | 0.527083 | 0.606250 | ... | 0.026235 | 3 | 0.619762 | 0.600375 | 0.621639 | 0.615882 | 0.598791 | 0.608170 | 0.610770 | 0.008976 |
| 170 | 16.618377 | 0.073812 | 0.416867 | 0.011814 | 10 | 1 | auto | rbf | 0.527083 | 0.606250 | ... | 0.026235 | 3 | 0.619762 | 0.600375 | 0.621639 | 0.615882 | 0.598791 | 0.608170 | 0.610770 | 0.008976 |
| 182 | 16.804817 | 0.061682 | 0.396373 | 0.014791 | 10 | 10 | scale | rbf | 0.548958 | 0.588542 | ... | 0.014049 | 5 | 0.692099 | 0.676464 | 0.691265 | 0.680909 | 0.671947 | 0.684660 | 0.682891 | 0.007338 |
| 192 | 16.457927 | 0.058380 | 0.425364 | 0.014310 | 10 | 10 | 0.01 | rbf | 0.536458 | 0.609375 | ... | 0.025550 | 6 | 0.600167 | 0.580571 | 0.602251 | 0.591496 | 0.573989 | 0.583576 | 0.588675 | 0.010261 |
| 154 | 16.561561 | 0.051005 | 0.454521 | 0.015179 | 10 | 0.1 | scale | rbf | 0.535417 | 0.601042 | ... | 0.023951 | 7 | 0.603294 | 0.579946 | 0.607671 | 0.588370 | 0.581284 | 0.587953 | 0.591420 | 0.010495 |
| 100 | 16.316972 | 0.044895 | 0.376713 | 0.008048 | 6 | 1 | auto | rbf | 0.536458 | 0.614583 | ... | 0.027659 | 8 | 0.596206 | 0.579737 | 0.599124 | 0.580450 | 0.573155 | 0.596290 | 0.587494 | 0.010033 |
| 190 | 16.729341 | 0.054006 | 0.399372 | 0.006625 | 10 | 10 | 0.1 | rbf | 0.534375 | 0.603125 | ... | 0.022818 | 9 | 0.658120 | 0.646654 | 0.659162 | 0.649229 | 0.631722 | 0.639642 | 0.647421 | 0.009688 |
| 184 | 16.623708 | 0.047814 | 0.408203 | 0.015823 | 10 | 10 | auto | rbf | 0.534375 | 0.603125 | ... | 0.022818 | 9 | 0.658120 | 0.646654 | 0.659162 | 0.649229 | 0.631722 | 0.639642 | 0.647421 | 0.009688 |
10 rows × 25 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 35 | 16.423605 | 0.131745 | 0.161782 | 0.007644 | 4 | 1 | 1 | sigmoid | 0.312500 | 0.311458 | ... | 0.013072 | 216 | 0.318532 | 0.311653 | 0.319575 | 0.329512 | 0.330346 | 0.328679 | 0.323050 | 0.006941 |
| 19 | 16.141195 | 0.214547 | 0.112631 | 0.010335 | 4 | 0.1 | 10 | sigmoid | 0.310417 | 0.316667 | ... | 0.011543 | 217 | 0.314572 | 0.325829 | 0.312070 | 0.331180 | 0.321801 | 0.322843 | 0.321382 | 0.006470 |
| 145 | 16.546566 | 0.075682 | 0.155117 | 0.004016 | 10 | 0.01 | 10 | sigmoid | 0.312500 | 0.311458 | ... | 0.023772 | 218 | 0.339379 | 0.329373 | 0.275797 | 0.330346 | 0.336390 | 0.329929 | 0.323536 | 0.021666 |
| 33 | 16.033729 | 0.093082 | 0.111298 | 0.013544 | 4 | 1 | 10 | sigmoid | 0.308333 | 0.315625 | ... | 0.012690 | 219 | 0.313946 | 0.324369 | 0.312904 | 0.332013 | 0.320133 | 0.323051 | 0.321070 | 0.006492 |
| 47 | 16.042726 | 0.157321 | 0.106799 | 0.007838 | 4 | 10 | 10 | sigmoid | 0.308333 | 0.315625 | ... | 0.012469 | 220 | 0.314989 | 0.325412 | 0.313946 | 0.331805 | 0.320759 | 0.323051 | 0.321660 | 0.006107 |
| 61 | 16.067385 | 0.152692 | 0.110465 | 0.006209 | 4 | 100 | 10 | sigmoid | 0.309375 | 0.315625 | ... | 0.012637 | 221 | 0.315197 | 0.325620 | 0.313738 | 0.331805 | 0.320550 | 0.323051 | 0.321660 | 0.006141 |
| 21 | 16.607712 | 0.104631 | 0.172945 | 0.006478 | 4 | 0.1 | 1 | sigmoid | 0.311458 | 0.306250 | ... | 0.015470 | 221 | 0.321451 | 0.310611 | 0.318949 | 0.327637 | 0.329721 | 0.330346 | 0.323119 | 0.006986 |
| 5 | 16.651032 | 0.078271 | 0.148786 | 0.012222 | 4 | 0.01 | 10 | sigmoid | 0.311458 | 0.256250 | ... | 0.033438 | 223 | 0.317699 | 0.284970 | 0.315614 | 0.288870 | 0.335348 | 0.320967 | 0.310578 | 0.017910 |
| 77 | 17.203523 | 0.098349 | 0.229427 | 0.004715 | 6 | 0.01 | 1 | sigmoid | 0.311458 | 0.294792 | ... | 0.023280 | 224 | 0.339170 | 0.328330 | 0.284553 | 0.336807 | 0.328679 | 0.331805 | 0.324891 | 0.018471 |
| 75 | 16.557396 | 0.113156 | 0.150285 | 0.006623 | 6 | 0.01 | 10 | sigmoid | 0.262500 | 0.309375 | ... | 0.034502 | 225 | 0.274964 | 0.323952 | 0.272670 | 0.324927 | 0.307003 | 0.326594 | 0.305018 | 0.023000 |
10 rows × 25 columns
In total there are 225 different configurations tested.
The best mean test score is 0.5736
There are 1 configurations with this maximum score
There are 43 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_SVM__C | param_SVM__gamma | param_SVM__kernel | split0_test_score | split1_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.0 | 43.0 | 38.00 | 43 | 43.000000 | 43.000000 | ... | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.000000 | 43.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 5.0 | 6.00 | 3 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 1.0 | 0.01 | rbf | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 27.0 | 12.0 | 8.00 | 34 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 17.055803 | 0.089568 | 0.348633 | 0.010835 | NaN | NaN | NaN | NaN | 0.536458 | 0.600388 | ... | 0.023164 | 21.883721 | 0.614226 | 0.598620 | 0.617392 | 0.604200 | 0.591899 | 0.604123 | 0.605077 | 0.009876 |
| std | 1.989175 | 0.127011 | 0.123650 | 0.005177 | NaN | NaN | NaN | NaN | 0.007429 | 0.010621 | ... | 0.004250 | 12.628783 | 0.044597 | 0.048566 | 0.042660 | 0.048044 | 0.047360 | 0.047643 | 0.046253 | 0.001656 |
| min | 16.307309 | 0.029035 | 0.089140 | 0.001572 | NaN | NaN | NaN | NaN | 0.523958 | 0.567708 | ... | 0.010383 | 1.000000 | 0.568689 | 0.562018 | 0.582656 | 0.563777 | 0.552730 | 0.560859 | 0.566511 | 0.006280 |
| 25% | 16.488584 | 0.053156 | 0.369383 | 0.007385 | NaN | NaN | NaN | NaN | 0.532292 | 0.595313 | ... | 0.022153 | 11.000000 | 0.587867 | 0.566396 | 0.590682 | 0.576907 | 0.559504 | 0.570446 | 0.574553 | 0.008758 |
| 50% | 16.561561 | 0.068989 | 0.396373 | 0.011664 | NaN | NaN | NaN | NaN | 0.534375 | 0.602083 | ... | 0.023160 | 21.000000 | 0.596206 | 0.579737 | 0.602460 | 0.582743 | 0.573155 | 0.587953 | 0.587494 | 0.009863 |
| 75% | 16.761331 | 0.086541 | 0.417950 | 0.014448 | NaN | NaN | NaN | NaN | 0.538542 | 0.607812 | ... | 0.026125 | 32.500000 | 0.622577 | 0.606108 | 0.621639 | 0.615882 | 0.603168 | 0.621509 | 0.614001 | 0.011514 |
| max | 29.269849 | 0.888109 | 0.478847 | 0.024861 | NaN | NaN | NaN | NaN | 0.556250 | 0.614583 | ... | 0.028891 | 43.000000 | 0.792996 | 0.797582 | 0.793621 | 0.803251 | 0.786578 | 0.804919 | 0.796491 | 0.012247 |
11 rows × 25 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before the visual cue
end_offset = 1 # One second after the visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
best_found_csp_components = [10, 10, 10]
best_found_svm_kernel = ["rbf", "rbf", "sigmoid"]
best_found_svm_c = [10, 10, 1]
best_found_svm_gamma = ["auto", "scale", "scale"]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    ################# TRAINING DATA #################
    with io.capture_output():
        # Determine the train subjects
        train_subjects = copy.deepcopy(subject_ids_to_test)
        train_subjects.remove(subject_ids_to_test[i])
        mne_raws = []
        # Get all training data
        for train_subject in train_subjects:
            mne_raws.extend(CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= train_subject))
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for that MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        y_train = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        X_train = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for training data
    del mne_raws
    del mne_raw
    del mne_epochs
    ################# TESTING DATA #################
    with io.capture_output():
        # Get test data
        mne_raws = CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Combine test data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for test MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        y_test = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        X_test = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
    # Delete residual vars for testing data
    del mne_raw
    del mne_epochs
    del mne_raws
    ################# FIT AND PREDICT #################
    # Make the classifier
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    svm = SVC(kernel= best_found_svm_kernel[i],
              C= best_found_svm_c[i],
              gamma= best_found_svm_gamma[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('SVM', svm)])
    # Fit the pipeline
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info, ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_svm_kernel
del best_found_svm_c
del best_found_svm_gamma
del i
del X_test
del y_test
del X_train
del y_train
del csp
del svm
del train_subjects
del train_subject
del pipeline
del y_pred
del accuracy
del start_offset
del end_offset
del baseline
del filter_lower_bound
del filter_upper_bound
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.3857
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.4411
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.3381
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
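This experiment works as follows: for each of the subjects B, C and E, the recordings of the other two subjects are combined into a single training set, filtered to the 2 Hz to 32 Hz band and cropped to the window from 0.1 s to 0.6 s after the visual cue. A grid search with 6-fold stratified cross-validation is then run on that training set for a CSP + RF pipeline, and its result is stored for later inspection.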
####################################################
# GRID SEARCHING BEST PIPELINE FOR EACH SUBJECT
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before the visual cue
end_offset = 1 # One second after the visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
do_experiment = False # Long experiment disabled per default
if do_experiment:
    # Loop over all subjects and perform the grid search for finding the best parameters
    for subject_id in subject_ids_to_test:
        ###################### PREPARE DATA ######################
        with io.capture_output():
            # Determine the train subjects
            train_subjects = copy.deepcopy(subject_ids_to_test)
            train_subjects.remove(subject_id)
            mne_raws = []
            # Get all training data
            for train_subject in train_subjects:
                mne_raws.extend(CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= train_subject))
            # Combine training data into singular mne raw
            mne_raw = mne.concatenate_raws(mne_raws)
            # Delete all raws since concat changes them
            del mne_raws
            # Get epochs for that MNE raw
            mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                                 start_offset= start_offset,
                                                                 end_offset= end_offset,
                                                                 baseline= baseline)
            # Only keep epochs from the MI tasks
            mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
            # Load epochs into memory
            mne_epochs.load_data()
        # Show training data
        print(f"Using data from participants {train_subjects} to train for testing on participant {subject_id}")
        # Get the labels
        labels = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        mne_epochs_data = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Configure the pipeline components by specifying the default parameters
        csp = CSP(norm_trace=False,
                  component_order="mutual_info",
                  cov_est= "epoch")
        rf = RandomForestClassifier(bootstrap= True,
                                    criterion= "gini")
        # Configure the pipeline
        pipeline = Pipeline([('CSP', csp), ('RF', rf)])
        # Configure cross validation to use, more splits than before since we have more data
        cv = StratifiedKFold(n_splits= 6,
                             shuffle= True,
                             random_state= 2022)
        # Configure the hyperparameters to test
        # NOTE: these are somewhat limited due to limited computational resources
        # NOTE: max_features needs the Python object None; the string "None" is not a
        # valid value, so fits using it fail and are scored NaN (the NaN rows in the
        # stored results likely stem from this)
        param_grid = [{"CSP__n_components": [4, 6, 10],
                       "RF__n_estimators": [10, 50, 100, 250, 500],
                       "RF__max_depth": [None, 3, 10],
                       "RF__min_samples_split": [2, 5, 10],
                       "RF__max_features": ["sqrt", "log2", None, 0.2, 0.4, 0.6]}]
        # Configure the grid search
        grid_search = GridSearchCV(estimator= pipeline,
                                   param_grid= param_grid,
                                   scoring= "accuracy",
                                   n_jobs= -1,
                                   refit= False, # We will do this manually
                                   cv= cv,
                                   verbose= 10,
                                   return_train_score= True)
        # Do the grid search on the training data
        grid_search.fit(X= mne_epochs_data, y= labels)
        # Store the results of the grid search
        with open(f"saved_variables/2/newsubject/subject{subject_id}/gridsearch_csprf.pickle", 'wb') as file:
            pickle.dump(grid_search, file)
        # Delete vars after singular experiment
        del mne_raw
        del mne_epochs
        del mne_epochs_data
        del csp
        del rf
        del pipeline
        del labels
        del cv
        del file
        del grid_search
        del param_grid
        del train_subject
        del train_subjects
    # Delete vars after all experiments
    del subject_id
# Del global vars
del subject_ids_to_test
del filter_lower_bound
del filter_upper_bound
del baseline
del do_experiment
del end_offset
del start_offset
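As a rough cross-check of the search space defined in the cell above, the number of configurations can be counted with the standard library alone. The sketch below copies the value lists from `param_grid`; only the number of values per parameter matters for the count. The names `grid`, `configs`, `n_fits` and `n_per_max_features` are illustrative and not part of the notebook.

```python
from itertools import product

# Value lists copied from the param_grid of the grid search cell;
# only the number of values per parameter matters for the count.
grid = {
    "CSP__n_components": [4, 6, 10],
    "RF__n_estimators": [10, 50, 100, 250, 500],
    "RF__max_depth": [None, 3, 10],
    "RF__min_samples_split": [2, 5, 10],
    "RF__max_features": ["sqrt", "log2", None, 0.2, 0.4, 0.6],
}

# Every combination of one value per parameter is one configuration
configs = [dict(zip(grid, values)) for values in product(*grid.values())]
n_configs = len(configs)

# Each configuration is fitted once per fold of the 6-fold cross-validation
n_splits = 6
n_fits = n_configs * n_splits

# Share of the grid taken up by a single max_features option
n_per_max_features = n_configs // len(grid["RF__max_features"])

print(f"{n_configs} configurations, {n_fits} fits, {n_per_max_features} per max_features value")
```

The count of 810 matches the number of configurations reported in the grid search results. Because each `max_features` option covers 135 configurations, a single option that cannot be fitted would leave that many rows without a score.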
The CV results are based on the training set alone and thus only look at the recordings of the two training subjects. The test result is for a new, unseen subject and thus scores are expected to differ.
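The difference between the two scores follows from how the grid search cell partitions the data: for each test subject, the remaining subjects form the training set, and cross-validation only ever sees that training set. A minimal sketch of that partitioning is given below. The helper `leave_one_subject_out` is hypothetical and not part of the notebook, which builds the same lists inline.

```python
# Subjects with three recordings, as in the notebook
subject_ids = ["B", "C", "E"]


def leave_one_subject_out(subjects):
    """Yield (train_subjects, test_subject) pairs, holding out one subject at a time."""
    for held_out in subjects:
        # All other subjects supply the training (and cross-validation) data
        train = [s for s in subjects if s != held_out]
        yield train, held_out


splits = list(leave_one_subject_out(subject_ids))
for train, test in splits:
    print(f"train and cross-validate on {train}, test on {test}")
```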
| Subject | CSP + RF: cross validation accuracy | CSP + RF: test split accuracy | CSP config | RF config |
|---|---|---|---|---|
| B (Train on C&E) | 0.5983 +- 0.0068 | 0.3923 | 10 CSP components | RF with None max depth, 0.2 max features, 2 min samples split and 500 estimators |
| C (Train on B&E) | 0.504 +- 0.0197 | 0.4571 | 10 CSP components | RF with None max depth, sqrt max features, 10 min samples split and 250 estimators |
| E (Train on B&C) | 0.572 +- 0.0274 | 0.3715 | 10 CSP components | RF with 10 max depth, log2 max features, 10 min samples split and 250 estimators |
Again, performance is poor, and LDA, SVM and RF perform almost identically. This suggests that the limiting factor here is the CSP feature extraction rather than the ML classifier.
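To put the CSP + RF numbers in perspective, the short sketch below computes, per subject, how far the accuracy drops from cross-validation to the unseen subject and how far the test accuracy sits above guessing. The accuracies are copied from the table above. The chance level of 1/3 is an assumption that holds only if the three classes (neutral, left, right) are roughly balanced, and the names `results`, `chance_level` and `summary` are illustrative.

```python
# Accuracies copied from the CSP + RF table: (cross-validation, test)
results = {
    "B": (0.5983, 0.3923),
    "C": (0.504, 0.4571),
    "E": (0.572, 0.3715),
}

# ASSUMPTION: three roughly balanced classes (neutral, left, right),
# so random guessing would score about 1/3
chance_level = 1 / 3

summary = {}
for subject, (cv_acc, test_acc) in results.items():
    summary[subject] = {
        # How much accuracy is lost when moving to the unseen subject
        "cv_to_test_drop": round(cv_acc - test_acc, 4),
        # How far the test accuracy sits above the assumed chance level
        "test_above_chance": round(test_acc - chance_level, 4),
    }
    print(subject, summary[subject])
```

Under that assumption, the drop from cross-validation to test is about 0.2 for subjects B and E and about 0.05 for subject C, while the test accuracies lie only about 0.04 to 0.12 above chance.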
####################################################
# GRID SEARCH RESULTS
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
# Loop over all found results
for subject_id in subject_ids_to_test:
    print("\n\n")
    print("####################################################")
    print(f"# GRID SEARCH RESULTS FOR SUBJECT {subject_id}")
    print("####################################################")
    print("\n\n")
    # Open from file
    with open(f"saved_variables/2/newsubject/subject{subject_id}/gridsearch_csprf.pickle", 'rb') as f:
        grid_search = pickle.load(f)
    # Print the results
    print(f"Best estimator has accuracy of {np.round(grid_search.best_score_, 4)} +- {np.round(grid_search.cv_results_['std_test_score'][grid_search.best_index_], 4)} with the following parameters")
    print(grid_search.best_params_)
    # Get grid search results
    grid_search_results = pd.DataFrame(grid_search.cv_results_)
    # Keep relevant columns and sort on rank
    grid_search_results.drop(labels='params', axis=1, inplace= True)
    grid_search_results.sort_values(by=['rank_test_score'], inplace=True)
    # Display grid search results
    print("\n\n Top 10 grid search results: ")
    display(grid_search_results.head(10))
    print("\n\n Worst 10 grid search results: ")
    display(grid_search_results.tail(10))
    # Display some statistics
    print(f"\n\nIn total there are {len(grid_search_results)} different configurations tested.")
    max_score = grid_search_results['mean_test_score'].max()
    print(f"The best mean test score is {round(max_score, 4)}")
    shared_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score, max_score)])
    print(f"There are {shared_first_place_count} configurations with this maximum score")
    close_first_place_count = len(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)])
    print(f"There are {close_first_place_count} configurations within 0.02 of this maximum score")
    # Display statistics for best classifiers
    print("\n\nThe describe of the configurations within 0.02 of this maximum score is as follows:")
    display(grid_search_results[grid_search_results['mean_test_score'].between(max_score-0.02, max_score)].describe(include="all"))
# Remove unused variables
del f
del grid_search
del max_score
del shared_first_place_count
del close_first_place_count
del grid_search_results
del subject_ids_to_test
del subject_id
####################################################
# GRID SEARCH RESULTS FOR SUBJECT B
####################################################
Best estimator has accuracy of 0.5983 +- 0.0068 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': None, 'RF__max_features': 0.2, 'RF__min_samples_split': 2, 'RF__n_estimators': 500}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 589 | 20.396840 | 0.119464 | 0.145453 | 0.001384 | 10 | None | 0.2 | 2 | 500 | 0.608968 | ... | 0.006821 | 1 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 564 | 21.546974 | 0.100174 | 0.150286 | 0.006573 | 10 | None | log2 | 5 | 500 | 0.601668 | ... | 0.008195 | 2 | 0.999791 | 0.999583 | 0.999791 | 0.999791 | 0.999165 | 0.999374 | 0.999583 | 0.000241 |
| 549 | 21.191920 | 0.102049 | 0.149952 | 0.006530 | 10 | None | sqrt | 5 | 500 | 0.605839 | ... | 0.005484 | 3 | 1.000000 | 0.999583 | 0.999583 | 0.999165 | 0.999583 | 1.000000 | 0.999652 | 0.000287 |
| 593 | 17.990606 | 0.074942 | 0.078975 | 0.001000 | 10 | None | 0.2 | 5 | 250 | 0.599583 | ... | 0.006059 | 4 | 1.000000 | 0.999374 | 0.998957 | 0.999165 | 0.999165 | 0.999374 | 0.999339 | 0.000328 |
| 599 | 19.955647 | 0.077388 | 0.139123 | 0.004945 | 10 | None | 0.2 | 10 | 500 | 0.600626 | ... | 0.004754 | 5 | 0.974541 | 0.973706 | 0.974958 | 0.975381 | 0.971834 | 0.974546 | 0.974161 | 0.001157 |
| 543 | 18.418969 | 0.092091 | 0.078309 | 0.002867 | 10 | None | sqrt | 2 | 250 | 0.606882 | ... | 0.008020 | 6 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 562 | 17.016915 | 0.072964 | 0.038488 | 0.001500 | 10 | None | log2 | 5 | 100 | 0.595412 | ... | 0.004059 | 7 | 0.998748 | 0.998331 | 0.998539 | 0.998331 | 0.998540 | 0.998122 | 0.998435 | 0.000200 |
| 614 | 22.602305 | 0.167663 | 0.137623 | 0.006234 | 10 | None | 0.4 | 10 | 500 | 0.602711 | ... | 0.007473 | 8 | 0.977880 | 0.975584 | 0.973080 | 0.975798 | 0.973712 | 0.975172 | 0.975204 | 0.001549 |
| 554 | 21.316381 | 0.120976 | 0.138123 | 0.004057 | 10 | None | sqrt | 10 | 500 | 0.604797 | ... | 0.007279 | 9 | 0.977045 | 0.974124 | 0.974332 | 0.976007 | 0.972668 | 0.974755 | 0.974822 | 0.001397 |
| 597 | 16.653864 | 0.062934 | 0.037322 | 0.001374 | 10 | None | 0.2 | 10 | 100 | 0.603754 | ... | 0.007325 | 10 | 0.970993 | 0.968280 | 0.969533 | 0.970791 | 0.968913 | 0.969330 | 0.969640 | 0.000970 |
10 rows × 26 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 395 | 15.437252 | 0.063028 | 0.0 | 0.0 | 6 | 3 | None | 5 | 10 | NaN | ... | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 15.514561 | 0.045379 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 15.491068 | 0.079607 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 15.473740 | 0.071803 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 15.474574 | 0.082528 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 15.560380 | 0.048912 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 398 | 15.522059 | 0.037724 | 0.0 | 0.0 | 6 | 3 | None | 5 | 250 | NaN | ... | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 397 | 15.457746 | 0.083142 | 0.0 | 0.0 | 6 | 3 | None | 5 | 100 | NaN | ... | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 15.566711 | 0.089764 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 15.692171 | 0.051239 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 26 columns
In total there are 810 different configurations tested.
The best mean test score is 0.5983
There are 1 configurations with this maximum score
There are 118 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 118.000000 | 118.000000 | 118.000000 | 1.180000e+02 | 118.0 | 59.0 | 118.0 | 118.0 | 118.0 | 118.000000 | ... | 118.000000 | 118.000000 | 118.000000 | 118.000000 | 118.000000 | 118.000000 | 118.000000 | 118.000000 | 118.000000 | 118.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 1.0 | 5.0 | 3.0 | 4.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | 0.2 | 10.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 118.0 | 59.0 | 24.0 | 40.0 | 30.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 18.523367 | 0.090960 | 0.068734 | 2.968795e-03 | NaN | NaN | NaN | NaN | NaN | 0.594334 | ... | 0.007290 | 59.483051 | 0.917279 | 0.922068 | 0.927538 | 0.914200 | 0.929240 | 0.922090 | 0.922069 | 0.006007 |
| std | 2.313453 | 0.032335 | 0.044180 | 2.197256e-03 | NaN | NaN | NaN | NaN | NaN | 0.007406 | ... | 0.002076 | 34.199448 | 0.075238 | 0.069765 | 0.064850 | 0.078148 | 0.062693 | 0.069958 | 0.070048 | 0.005384 |
| min | 16.110704 | 0.028613 | 0.022993 | 2.973602e-07 | NaN | NaN | NaN | NaN | NaN | 0.571429 | ... | 0.001496 | 1.000000 | 0.802170 | 0.821369 | 0.827629 | 0.800751 | 0.825370 | 0.810557 | 0.815406 | 0.000000 |
| 25% | 16.710180 | 0.069915 | 0.033906 | 1.076640e-03 | NaN | NaN | NaN | NaN | NaN | 0.589155 | ... | 0.006116 | 30.250000 | 0.850219 | 0.859401 | 0.872757 | 0.842896 | 0.871740 | 0.860421 | 0.858799 | 0.000370 |
| 50% | 17.727023 | 0.087135 | 0.065479 | 2.362203e-03 | NaN | NaN | NaN | NaN | NaN | 0.594891 | ... | 0.007050 | 59.500000 | 0.918406 | 0.925292 | 0.930196 | 0.916336 | 0.933340 | 0.925412 | 0.924361 | 0.005790 |
| 75% | 19.958521 | 0.110311 | 0.109424 | 4.458532e-03 | NaN | NaN | NaN | NaN | NaN | 0.599583 | ... | 0.008702 | 88.750000 | 0.999165 | 0.999061 | 0.998696 | 0.998748 | 0.998748 | 0.998696 | 0.998809 | 0.011394 |
| max | 25.622677 | 0.181459 | 0.160616 | 9.119440e-03 | NaN | NaN | NaN | NaN | NaN | 0.612096 | ... | 0.013020 | 118.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.013684 |
11 rows × 26 columns
####################################################
# GRID SEARCH RESULTS FOR SUBJECT C
####################################################
Best estimator has accuracy of 0.504 +- 0.0197 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': None, 'RF__max_features': 'sqrt', 'RF__min_samples_split': 10, 'RF__n_estimators': 250}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 553 | 18.635067 | 0.116426 | 0.076476 | 0.001258 | 10 | None | sqrt | 10 | 250 | 0.524505 | ... | 0.019697 | 1 | 0.986642 | 0.985181 | 0.985184 | 0.985601 | 0.987270 | 0.987688 | 0.986261 | 0.000997 |
| 769 | 19.202720 | 0.065640 | 0.122295 | 0.003036 | 10 | 10 | 0.2 | 2 | 500 | 0.508863 | ... | 0.015240 | 2 | 0.836777 | 0.828637 | 0.852045 | 0.854549 | 0.837229 | 0.835559 | 0.840799 | 0.009308 |
| 728 | 18.026761 | 0.072242 | 0.067979 | 0.005506 | 10 | 10 | sqrt | 5 | 250 | 0.514077 | ... | 0.018853 | 3 | 0.821958 | 0.811313 | 0.823664 | 0.834098 | 0.820743 | 0.813022 | 0.820800 | 0.007493 |
| 784 | 21.599124 | 0.049293 | 0.129792 | 0.005667 | 10 | 10 | 0.4 | 2 | 500 | 0.506778 | ... | 0.018060 | 4 | 0.848257 | 0.835525 | 0.851210 | 0.854967 | 0.851210 | 0.836811 | 0.846330 | 0.007453 |
| 568 | 18.640399 | 0.027713 | 0.077142 | 0.003530 | 10 | None | log2 | 10 | 250 | 0.506778 | ... | 0.014630 | 5 | 0.985598 | 0.985807 | 0.986436 | 0.986853 | 0.987062 | 0.988105 | 0.986643 | 0.000836 |
| 543 | 18.800348 | 0.091748 | 0.085306 | 0.004569 | 10 | None | sqrt | 2 | 250 | 0.516163 | ... | 0.019542 | 6 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 793 | 18.594747 | 0.102751 | 0.065146 | 0.000687 | 10 | 10 | 0.4 | 10 | 250 | 0.505735 | ... | 0.017471 | 7 | 0.794615 | 0.777082 | 0.798623 | 0.811352 | 0.788815 | 0.789232 | 0.793286 | 0.010453 |
| 773 | 17.441281 | 0.074579 | 0.066979 | 0.004581 | 10 | 10 | 0.2 | 5 | 250 | 0.507821 | ... | 0.010542 | 8 | 0.817783 | 0.815070 | 0.830968 | 0.834516 | 0.808639 | 0.814900 | 0.820313 | 0.009262 |
| 613 | 19.381663 | 0.051025 | 0.076476 | 0.003946 | 10 | None | 0.4 | 10 | 250 | 0.517205 | ... | 0.013284 | 9 | 0.986433 | 0.984346 | 0.986018 | 0.986853 | 0.986227 | 0.986436 | 0.986052 | 0.000804 |
| 732 | 16.619709 | 0.058036 | 0.033490 | 0.000500 | 10 | 10 | sqrt | 10 | 100 | 0.504692 | ... | 0.013862 | 10 | 0.788771 | 0.775204 | 0.788189 | 0.807179 | 0.787563 | 0.781093 | 0.788000 | 0.009833 |
10 rows × 26 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 395 | 15.575874 | 0.073233 | 0.0 | 0.0 | 6 | 3 | None | 5 | 10 | NaN | ... | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 15.624526 | 0.083111 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 15.747320 | 0.060066 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 15.591869 | 0.066770 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 15.587371 | 0.071206 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 15.706500 | 0.052762 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 398 | 15.688672 | 0.048909 | 0.0 | 0.0 | 6 | 3 | None | 5 | 250 | NaN | ... | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 397 | 15.570210 | 0.069659 | 0.0 | 0.0 | 6 | 3 | None | 5 | 100 | NaN | ... | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 15.727326 | 0.058782 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 15.805801 | 0.062113 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 26 columns
In total there are 810 different configurations tested.
The best mean test score is 0.504
There are 1 configurations with this maximum score
There are 121 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.0 | 62.0 | 121 | 121.0 | 121.0 | 121.000000 | ... | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.000000 | 121.000000 |
| unique | NaN | NaN | NaN | NaN | 1.0 | 1.0 | 5 | 3.0 | 5.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | sqrt | 10.0 | 250.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 121.0 | 62.0 | 25 | 41.0 | 30.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 18.434562 | 0.086518 | 0.067312 | 0.002797 | NaN | NaN | NaN | NaN | NaN | 0.505976 | ... | 0.016065 | 60.983471 | 0.901184 | 0.895796 | 0.903739 | 0.907033 | 0.900444 | 0.897979 | 0.901029 | 0.004732 |
| std | 2.458859 | 0.052362 | 0.044733 | 0.002920 | NaN | NaN | NaN | NaN | NaN | 0.008143 | ... | 0.002831 | 35.078480 | 0.092733 | 0.097573 | 0.090266 | 0.087689 | 0.094143 | 0.096532 | 0.093068 | 0.004374 |
| min | 15.776477 | 0.025723 | 0.014662 | 0.000372 | NaN | NaN | NaN | NaN | NaN | 0.477581 | ... | 0.006824 | 1.000000 | 0.740138 | 0.746399 | 0.742070 | 0.715776 | 0.713481 | 0.718280 | 0.729357 | 0.000000 |
| 25% | 16.486251 | 0.059938 | 0.028991 | 0.000897 | NaN | NaN | NaN | NaN | NaN | 0.500521 | ... | 0.014393 | 31.000000 | 0.818618 | 0.807973 | 0.823664 | 0.831386 | 0.818030 | 0.812187 | 0.819652 | 0.000231 |
| 50% | 17.441281 | 0.079024 | 0.044653 | 0.001999 | NaN | NaN | NaN | NaN | NaN | 0.505735 | ... | 0.016201 | 61.000000 | 0.844500 | 0.836151 | 0.851210 | 0.854549 | 0.851210 | 0.841194 | 0.846121 | 0.004803 |
| 75% | 19.828854 | 0.100487 | 0.085306 | 0.003946 | NaN | NaN | NaN | NaN | NaN | 0.509906 | ... | 0.018151 | 91.000000 | 0.999583 | 0.999583 | 0.999374 | 0.999583 | 0.999791 | 0.999583 | 0.999443 | 0.008864 |
| max | 26.361274 | 0.554112 | 0.162782 | 0.024923 | NaN | NaN | NaN | NaN | NaN | 0.528676 | ... | 0.021420 | 121.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.013708 |
11 rows × 26 columns
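The summary above (total configurations, best mean test score, and the configurations within 0.02 of it) can be derived from a grid search's `cv_results_`. The following is a minimal sketch, not the notebook's own helper: the function name `summarize_near_best` and the toy data are hypothetical, and it assumes failed fits show up as NaN in `mean_test_score`.

```python
import numpy as np
import pandas as pd


def summarize_near_best(cv_results, tolerance=0.02):
    """Return (best_score, near_best_rows) for a cv_results_-like dict.

    Hypothetical helper: rows with a NaN mean_test_score (failed fits)
    are ignored when looking for the best score.
    """
    results = pd.DataFrame(cv_results)
    best_score = results["mean_test_score"].max()  # max() skips NaN by default
    # NaN compares as False, so failed fits are excluded here as well
    near_best = results[results["mean_test_score"] >= best_score - tolerance]
    return best_score, near_best


# Toy stand-in for GridSearchCV.cv_results_ (values are illustrative only)
toy_results = {
    "param_RF__n_estimators": [100, 250, 500, 250],
    "mean_test_score": [0.49, 0.504, 0.47, np.nan],
}

best, near = summarize_near_best(toy_results)
print(f"In total there are {len(toy_results['mean_test_score'])} different configurations tested.")
print(f"The best mean test score is {round(best, 3)}")
print(f"There are {len(near)} configurations within 0.02 of this maximum score")
print(near.describe(include="all"))
```

Passing `include="all"` to `describe` is what produces the `unique`, `top` and `freq` rows alongside the numeric statistics.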
####################################################
# GRID SEARCH RESULTS FOR SUBJECT E
####################################################
Best estimator has accuracy of 0.572 +- 0.0274 with the following parameters
{'CSP__n_components': 10, 'RF__max_depth': 10, 'RF__max_features': 'log2', 'RF__min_samples_split': 10, 'RF__n_estimators': 250}
Top 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 748 | 17.974061 | 0.036909 | 0.069525 | 0.005292 | 10 | 10 | log2 | 10 | 250 | 0.551042 | ... | 0.027421 | 1 | 0.828851 | 0.824474 | 0.835105 | 0.805127 | 0.815340 | 0.821592 | 0.821748 | 0.009606 |
| 723 | 18.040514 | 0.027358 | 0.072144 | 0.006666 | 10 | 10 | sqrt | 2 | 250 | 0.550000 | ... | 0.026275 | 2 | 0.880967 | 0.872837 | 0.883886 | 0.857441 | 0.873697 | 0.871405 | 0.873372 | 0.008429 |
| 724 | 20.379906 | 0.075229 | 0.132791 | 0.002543 | 10 | 10 | sqrt | 2 | 500 | 0.545833 | ... | 0.023645 | 3 | 0.879716 | 0.869502 | 0.889097 | 0.858483 | 0.873906 | 0.871822 | 0.873754 | 0.009363 |
| 738 | 17.946493 | 0.067233 | 0.069855 | 0.002548 | 10 | 10 | log2 | 2 | 250 | 0.546875 | ... | 0.022907 | 4 | 0.889306 | 0.869502 | 0.885970 | 0.857024 | 0.871613 | 0.873072 | 0.874415 | 0.010736 |
| 599 | 19.761754 | 0.066341 | 0.140789 | 0.001572 | 10 | None | 0.2 | 10 | 500 | 0.546875 | ... | 0.023273 | 5 | 0.974567 | 0.973317 | 0.971857 | 0.975406 | 0.972905 | 0.973531 | 0.973597 | 0.001140 |
| 554 | 21.076969 | 0.053030 | 0.141088 | 0.004210 | 10 | None | sqrt | 10 | 500 | 0.557292 | ... | 0.022700 | 6 | 0.973108 | 0.975610 | 0.973108 | 0.976865 | 0.975406 | 0.975823 | 0.974987 | 0.001405 |
| 788 | 18.619474 | 0.056110 | 0.071977 | 0.001000 | 10 | 10 | 0.4 | 5 | 250 | 0.542708 | ... | 0.023626 | 7 | 0.873254 | 0.866375 | 0.882218 | 0.856607 | 0.863693 | 0.869321 | 0.868578 | 0.007963 |
| 589 | 20.146783 | 0.094316 | 0.151330 | 0.002314 | 10 | None | 0.2 | 2 | 500 | 0.553125 | ... | 0.020125 | 8 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
| 782 | 16.848380 | 0.075064 | 0.035989 | 0.001154 | 10 | 10 | 0.4 | 2 | 100 | 0.548958 | ... | 0.023685 | 9 | 0.888680 | 0.878883 | 0.894309 | 0.866403 | 0.880158 | 0.881617 | 0.881675 | 0.008680 |
| 593 | 17.884179 | 0.078220 | 0.081308 | 0.000943 | 10 | None | 0.2 | 5 | 250 | 0.547917 | ... | 0.026656 | 10 | 0.999792 | 0.998958 | 0.999792 | 0.999375 | 1.000000 | 0.998541 | 0.999409 | 0.000516 |
10 rows × 26 columns
Worst 10 grid search results:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 673 | 15.740538 | 0.035517 | 0.0 | 0.0 | 10 | 3 | None | 10 | 250 | NaN | ... | NaN | 801 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 482 | 15.650189 | 0.041667 | 0.0 | 0.0 | 6 | 10 | None | 2 | 100 | NaN | ... | NaN | 802 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 480 | 15.643253 | 0.045496 | 0.0 | 0.0 | 6 | 10 | None | 2 | 10 | NaN | ... | NaN | 803 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 403 | 15.743871 | 0.060389 | 0.0 | 0.0 | 6 | 3 | None | 10 | 250 | NaN | ... | NaN | 804 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 402 | 15.724227 | 0.056559 | 0.0 | 0.0 | 6 | 3 | None | 10 | 100 | NaN | ... | NaN | 805 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 401 | 15.632754 | 0.032020 | 0.0 | 0.0 | 6 | 3 | None | 10 | 50 | NaN | ... | NaN | 806 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 400 | 15.648134 | 0.054615 | 0.0 | 0.0 | 6 | 3 | None | 10 | 10 | NaN | ... | NaN | 807 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 399 | 15.752221 | 0.020398 | 0.0 | 0.0 | 6 | 3 | None | 5 | 500 | NaN | ... | NaN | 808 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 481 | 15.648571 | 0.037089 | 0.0 | 0.0 | 6 | 10 | None | 2 | 50 | NaN | ... | NaN | 809 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 404 | 15.799369 | 0.033787 | 0.0 | 0.0 | 6 | 3 | None | 10 | 500 | NaN | ... | NaN | 810 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
10 rows × 26 columns
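Every row in the worst-10 table has NaN scores, a score time of 0.0 and `param_RF__max_features` shown as None. This suggests those fits failed rather than scored poorly, since scikit-learn's `GridSearchCV` records a failed fit as NaN under its default `error_score`. The cause is not visible in this output. A minimal sketch for tallying which parameter values the failed candidates share, assuming a `cv_results_`-like dict; the helper name `count_failed_param_values` and the toy data are hypothetical:

```python
import math
from collections import Counter


def count_failed_param_values(cv_results):
    """Tally parameter values among candidates whose mean test score is NaN."""
    counts = Counter()
    for params, score in zip(cv_results["params"], cv_results["mean_test_score"]):
        if math.isnan(score):
            for name, value in params.items():
                counts[(name, value)] += 1
    return counts


# Toy stand-in for GridSearchCV.cv_results_ (illustrative values only)
toy_results = {
    "params": [
        {"RF__max_features": "sqrt", "RF__n_estimators": 250},
        {"RF__max_features": None, "RF__n_estimators": 250},
        {"RF__max_features": None, "RF__n_estimators": 500},
    ],
    "mean_test_score": [0.55, float("nan"), float("nan")],
}

failed = count_failed_param_values(toy_results)
for (name, value), count in failed.most_common():
    print(f"{name}={value!r}: {count} failed fits")
```

A parameter value that appears in every failed candidate is a natural first suspect when tracking down the failure.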
In total there are 810 different configurations tested.
The best mean test score is 0.572
There are 1 configurations with this maximum score
There are 184 configurations within 0.02 of this maximum score
The describe of the configurations within 0.02 of this maximum score is as follows:
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_CSP__n_components | param_RF__max_depth | param_RF__max_features | param_RF__min_samples_split | param_RF__n_estimators | split0_test_score | ... | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.0 | 117.0 | 184.0 | 184.0 | 184.0 | 184.000000 | ... | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.000000 | 184.000000 |
| unique | NaN | NaN | NaN | NaN | 2.0 | 2.0 | 5.0 | 3.0 | 5.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| top | NaN | NaN | NaN | NaN | 10.0 | 10.0 | 0.2 | 10.0 | 500.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| freq | NaN | NaN | NaN | NaN | 121.0 | 115.0 | 39.0 | 64.0 | 48.0 | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| mean | 17.993735 | 0.058306 | 0.068691 | 0.003144 | NaN | NaN | NaN | NaN | NaN | 0.540138 | ... | 0.024112 | 92.483696 | 0.884980 | 0.877607 | 0.881343 | 0.873900 | 0.878506 | 0.878588 | 0.879154 | 0.005305 |
| std | 2.053023 | 0.027998 | 0.044409 | 0.002626 | NaN | NaN | NaN | NaN | NaN | 0.007999 | ... | 0.002554 | 53.273803 | 0.088171 | 0.093663 | 0.092499 | 0.095760 | 0.092634 | 0.092377 | 0.092383 | 0.003637 |
| min | 15.743215 | 0.016524 | 0.011330 | 0.000373 | NaN | NaN | NaN | NaN | NaN | 0.515625 | ... | 0.018583 | 1.000000 | 0.563894 | 0.564728 | 0.576819 | 0.566069 | 0.551897 | 0.572322 | 0.566545 | 0.000000 |
| 25% | 16.397151 | 0.041205 | 0.031448 | 0.001067 | NaN | NaN | NaN | NaN | NaN | 0.534375 | ... | 0.022676 | 46.750000 | 0.824630 | 0.812539 | 0.813112 | 0.802105 | 0.811119 | 0.812370 | 0.813515 | 0.001125 |
| 50% | 17.402916 | 0.053118 | 0.066000 | 0.002296 | NaN | NaN | NaN | NaN | NaN | 0.541667 | ... | 0.023931 | 92.500000 | 0.863456 | 0.855222 | 0.871795 | 0.847541 | 0.855773 | 0.855982 | 0.858260 | 0.006390 |
| 75% | 19.133591 | 0.068942 | 0.118619 | 0.004520 | NaN | NaN | NaN | NaN | NaN | 0.545833 | ... | 0.025917 | 138.250000 | 0.973108 | 0.976235 | 0.974255 | 0.977022 | 0.975511 | 0.975354 | 0.975065 | 0.008333 |
| max | 25.447294 | 0.232704 | 0.163992 | 0.014356 | NaN | NaN | NaN | NaN | NaN | 0.564583 | ... | 0.031857 | 184.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.010891 |
11 rows × 26 columns
####################################################
# TEST RESULTS FOR BEST FOUND GRID SEARCH
####################################################
# Configure global parameters for all experiments
subject_ids_to_test = ["B", "C", "E"] # Subjects with three recordings
start_offset = -1 # One second before visual cue
end_offset = 1 # One second after visual cue
baseline = (None, 0) # Baseline correction using data before the visual cue
filter_lower_bound = 2 # Filter out any frequency below this
filter_upper_bound = 32 # Filter out any frequency above this
# Best parameters found by the grid search, one entry per subject in subject_ids_to_test
best_found_csp_components = [10, 10, 10]
best_found_rf_depth = [None, None, 10]
best_found_rf_max_featues = [0.2, "sqrt", "log2"]
best_found_rf_min_sample = [2, 10, 10]
best_found_rf_n_estimators = [500, 250, 250]
# Loop over all found results
for i in range(len(subject_ids_to_test)):
    print("\n\n")
    print("####################################################")
    print(f"# TEST RESULTS FOR SUBJECT {subject_ids_to_test[i]}")
    print("####################################################")
    print("\n\n")
    ################# TRAINING DATA #################
    with io.capture_output():
        # Determine the train subjects: all subjects except the one being tested
        train_subjects = copy.deepcopy(subject_ids_to_test)
        train_subjects.remove(subject_ids_to_test[i])
        mne_raws = []
        # Get all training data
        for train_subject in train_subjects:
            mne_raws.extend(CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= train_subject))
        # Combine training data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for that MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        y_train = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        X_train = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Delete residual vars for training data
        del mne_raws
        del mne_raw
        del mne_epochs
    ################# TESTING DATA #################
    with io.capture_output():
        # Get test data
        mne_raws = CLA_dataset.get_all_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i])
        # Combine test data into singular mne raw
        mne_raw = mne.concatenate_raws(mne_raws)
        # Get epochs for test MNE raw
        mne_epochs = CLA_dataset.get_usefull_epochs_from_raw(mne_raw,
                                                             start_offset= start_offset,
                                                             end_offset= end_offset,
                                                             baseline= baseline)
        # Only keep epochs from the MI tasks
        mne_epochs = mne_epochs['task/neutral', 'task/left', 'task/right']
        # Load epochs into memory
        mne_epochs.load_data()
        # Get the labels
        y_test = mne_epochs.events[:, -1]
        # Use a fixed filter
        mne_epochs.filter(l_freq= filter_lower_bound,
                          h_freq= filter_upper_bound,
                          picks= "all",
                          phase= "minimum",
                          fir_window= "blackman",
                          fir_design= "firwin",
                          pad= 'median',
                          n_jobs= -1,
                          verbose= False)
        # Get a half second window
        X_test = mne_epochs.get_data(tmin= 0.1, tmax= 0.6)
        # Delete residual vars for testing data
        del mne_raw
        del mne_epochs
        del mne_raws
    ################# FIT AND PREDICT #################
    # Make the classifier
    csp = CSP(norm_trace=False,
              component_order="mutual_info",
              cov_est= "epoch",
              n_components= best_found_csp_components[i])
    rf = RandomForestClassifier(bootstrap= True,
                                criterion= "gini",
                                max_depth= best_found_rf_depth[i],
                                max_features= best_found_rf_max_featues[i],
                                min_samples_split= best_found_rf_min_sample[i],
                                n_estimators= best_found_rf_n_estimators[i])
    # Configure the pipeline
    pipeline = Pipeline([('CSP', csp), ('RF', rf)])
    # Fit the pipeline
    with io.capture_output():
        pipeline.fit(X_train, y_train)
    # Get accuracy for single fit
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Print accuracy results and CM
    print(f"Test accuracy for subject {subject_ids_to_test[i]}: {np.round(accuracy, 4)}")
    ConfusionMatrixDisplay.from_predictions(y_true= y_test, y_pred= y_pred)
    plt.show()
    # Plot CSP patterns estimated on train data for visualization
    pipeline['CSP'].plot_patterns(CLA_dataset.get_last_raw_mne_data_for_subject(subject_id= subject_ids_to_test[i]).info,
                                  ch_type='eeg', units='Patterns (AU)', size=1.5)
    plt.show()
# Remove unused variables
del subject_ids_to_test
del best_found_csp_components
del best_found_rf_depth
del best_found_rf_max_featues
del best_found_rf_min_sample
del best_found_rf_n_estimators
del i
del X_test
del y_test
del X_train
del y_train
del csp
del rf
del train_subjects
del train_subject
del pipeline
del y_pred
del accuracy
del start_offset
del end_offset
del baseline
del filter_lower_bound
del filter_upper_bound
####################################################
# TEST RESULTS FOR SUBJECT B
####################################################
Test accuracy for subject B: 0.3909
Reading 0 ... 667799 = 0.000 ... 3338.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT C
####################################################
Test accuracy for subject C: 0.4453
Reading 0 ... 669399 = 0.000 ... 3346.995 secs...
####################################################
# TEST RESULTS FOR SUBJECT E
####################################################
Test accuracy for subject E: 0.3708
Reading 0 ... 666999 = 0.000 ... 3334.995 secs...
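To put the three test accuracies in context: with three classes (neutral, left, right), a classifier guessing uniformly would score about 1/3. This assumes roughly balanced classes, which is an assumption here because the class counts are not shown in this output. A small sketch of that comparison using the accuracies printed above:

```python
# Test accuracies as printed above
test_accuracy = {"B": 0.3909, "C": 0.4453, "E": 0.3708}

# Assumption: three roughly balanced classes, so chance level is 1/3
chance_level = 1 / 3

# Margin of each held-out subject's accuracy over the assumed chance level
margins = {subject: acc - chance_level for subject, acc in test_accuracy.items()}
for subject, margin in margins.items():
    print(f"Subject {subject}: {margin:+.4f} relative to chance")

mean_accuracy = sum(test_accuracy.values()) / len(test_accuracy)
print(f"Mean test accuracy: {mean_accuracy:.4f}")
```

All three subjects land above the assumed chance level, with subject C showing the largest margin.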